From 7a388c1846fce4060a7f4fac75c9f948a0aea872 Mon Sep 17 00:00:00 2001 From: Jordan Rupprecht Date: Wed, 29 Oct 2025 11:31:54 -0700 Subject: [PATCH] [mlir][sparse][NFC] Include sparse emit strategy in wrapping iterator --- .../Transforms/Utils/SparseTensorIterator.cpp | 18 +++++++++++++----- .../Transforms/Utils/SparseTensorIterator.h | 6 +++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp index 46d0baac58f06..61b5ad600a16e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp @@ -504,6 +504,14 @@ class SimpleWrapIterator : public SparseIterator { unsigned extraCursorVal = 0) : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {} + void setSparseEmitStrategy(SparseEmitStrategy strategy) override { + wrap->setSparseEmitStrategy(strategy); + } + + SparseEmitStrategy getSparseEmitStrategy() const override { + return wrap->getSparseEmitStrategy(); + } + SmallVector getCursorValTypes(OpBuilder &b) const override { return wrap->getCursorValTypes(b); } @@ -979,7 +987,7 @@ class SubSectIterator : public SparseIterator { void SparseIterator::genInit(OpBuilder &b, Location l, const SparseIterator *p) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {}, getCursorValTypes(b)); @@ -994,7 +1002,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l, } Value SparseIterator::genNotEnd(OpBuilder &b, Location l) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"), getCursor(), b.getI1Type()); @@ -1005,7 +1013,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) { } void SparseIterator::locate(OpBuilder &b, Location l, Value crd) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); SmallVector args = getCursor(); args.push_back(crd); @@ -1019,7 +1027,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) { } Value SparseIterator::deref(OpBuilder &b, Location l) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); SmallVector args = getCursor(); Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"), @@ -1032,7 +1040,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) { ValueRange SparseIterator::forward(OpBuilder &b, Location l) { assert(!randomAccessible()); - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *next = b.create(l, b.getStringAttr(prefix + ".next"), getCursor(), getCursorValTypes(b)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 642cb1afa156b..3636f3f01adb5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -177,10 +177,14 @@ class SparseIterator { public: virtual ~SparseIterator() = default; - void setSparseEmitStrategy(SparseEmitStrategy strategy) { + virtual void setSparseEmitStrategy(SparseEmitStrategy strategy) { emitStrategy = strategy; } + virtual SparseEmitStrategy getSparseEmitStrategy() const { + return emitStrategy; + } + virtual std::string getDebugInterfacePrefix() const = 0; virtual SmallVector getCursorValTypes(OpBuilder &b) const = 0;