-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][tosa] Add a pass to narrow i64 to i32 #165581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This pass aims to narrow i64 types on TOSA operations to i32. It comes with the following options: - "aggressive-rewrite" - This option is typically able to narrow more values, but may impact numerical behaviour if not used carefully. - "convert-function-boundaries" - If enabled, parameters/ results to/from a function may be narrowed. Otherwise, casts are inserted to preserve the I/O of the function. Currently the non aggressive mode is very limited, targeting an argmax -> cast sequence that has been observed during legalization as well as some data layout operations that can always narrow. Support for more operations will be added in the future. Co-authored-by: Vitalii Shutov <vitalii.shutov@arm.com> Co-authored-by: Shubham <shubham@arm.com> Co-authored-by: Declan Flavin <declan.flavin@arm.com> Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: Ia8a766c88e6f8e8d019bce6c47114ce3b8a06969
|
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThis pass aims to narrow i64 types on TOSA operations to i32. It comes with the following options:
Currently the non aggressive mode is very limited, targeting an argmax -> cast sequence that has been observed during legalization as well as some data layout operations that can always narrow. Support for more operations will be added in the future. Co-authored-by: Vitalii Shutov <vitalii.shutov@arm.com> Patch is 27.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165581.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 14b00b04ccc18..420e58192b8fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -166,4 +166,27 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
}
+def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
+ let summary = "Narrow I64 TOSA operations to I32";
+ let description = [{
+ This pass narrows TOSA operations with 64-bit integer tensor types to
+ 32-bit integer tensor types. This can be useful for backends that do not
+ support the EXT-INT64 extension of TOSA.
+ }];
+
+ let options = [
+ Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
+ "If enabled, all TOSA operations are rewritten, regardless or whether the narrowing"
+ "is safe. This option may lead to data loss if not used carefully.">,
+ Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
+ "If enabled, the pass will convert function I/O types as well. Otherwise casts will"
+ "be inserted at the I/O boundaries.">
+ ];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d6e7189..987ce4ed870c9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
+ TosaNarrowI64ToI32.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
new file mode 100644
index 0000000000000..ddaf7d8a5e033
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
@@ -0,0 +1,310 @@
+//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass narrows TOSA operations with 64-bit integer tensor types to
+// 32-bit integer tensor types. This can be useful for backends that do not
+// support the EXT-INT64 extension of TOSA. The pass has two options:
+//
+// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
+// regardless or whether the narrowing is safe. This option may lead to
+// data loss if not used carefully.
+// - convert-function-boundaries - If enabled, the pass will convert function
+// I/O types as well. Otherwise casts will be inserted at the I/O
+// boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter *typeConverter) {
+ // Convert types of results
+ SmallVector<Type, 4> newResults;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+ return failure();
+
+ // Create a new operation state
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, {}, op->getSuccessors());
+
+ for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+ const Attribute attribute = namedAttribute.getValue();
+
+ // Convert integer attribute type
+ if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+ Type type = typeAttr.getValue();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, attribute);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(op,
+ "Failed to convert type attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+ const Type type = denseElementsAttr.getType();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, denseElementsAttr);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(
+ op, "Failed to convert dense elements attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ state.addAttribute(namedAttribute.getName(), attribute);
+ }
+
+ for (Region ®ion : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+ return failure();
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+}
+
+// ===========================
+// Aggressive rewrite patterns
+// ===========================
+
+class ConvertGenericOp : public ConversionPattern {
+public:
+ ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!isa<tosa::TosaOp>(op))
+ return rewriter.notifyMatchFailure(
+ op,
+ "Support for operations other than TOSA has not been implemented.");
+
+ return convertGenericOp(op, operands, rewriter, typeConverter);
+ }
+};
+
+// ===============================
+// Bounds checked rewrite patterns
+// ===============================
+
+class ConvertArgMaxOpWithBoundsChecking
+ : public OpConversionPattern<tosa::ArgMaxOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Output type can be narrowed based on the size of the axis dimension
+ const int32_t axis = op.getAxis();
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ if (!inputType || !inputType.isStaticDim(axis))
+ return rewriter.notifyMatchFailure(
+ op, "Requires a static axis dimension for bounds checking.");
+ const int64_t axisDim = inputType.getDimSize(axis);
+ if (axisDim >= std::numeric_limits<int32_t>::max())
+ return rewriter.notifyMatchFailure(
+ op, "Axis dimension is too large to narrow safely.");
+
+ const Type resultType = op.getOutput().getType();
+ const Type newResultType = typeConverter->convertType(resultType);
+ rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+ adaptor.getInput(), axis);
+ return success();
+ }
+};
+
+class ConvertCastOpWithBoundsChecking
+ : public OpConversionPattern<tosa::CastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+ if (!inputType || !resultType)
+ return failure();
+
+ const auto elementInputIntType =
+ dyn_cast<IntegerType>(inputType.getElementType());
+ const auto elementResultIntType =
+ dyn_cast<IntegerType>(resultType.getElementType());
+ if (elementInputIntType && elementResultIntType &&
+ elementInputIntType.getWidth() > elementResultIntType.getWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+
+ rewriter.replaceOpWithNewOp<tosa::CastOp>(
+ op, typeConverter->convertType(resultType), adaptor.getInput());
+ return success();
+ }
+};
+
+template <typename OpTy>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ return convertGenericOp(op, adaptor.getOperands(), rewriter,
+ this->getTypeConverter());
+ }
+};
+
+struct TosaNarrowI64ToI32
+ : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+public:
+ explicit TosaNarrowI64ToI32() = default;
+ explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
+ : TosaNarrowI64ToI32() {
+ this->aggressiveRewrite = options.aggressiveRewrite;
+ this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) -> Type { return type; });
+ typeConverter.addConversion([](IntegerType type) -> Type {
+ if (!type.isInteger(64))
+ return type;
+ return IntegerType::get(type.getContext(), 32);
+ });
+ typeConverter.addConversion(
+ [&typeConverter](RankedTensorType type) -> Type {
+ const Type elementType = type.getElementType();
+ if (!elementType.isInteger(64))
+ return type;
+ return RankedTensorType::get(type.getShape(),
+ typeConverter.convertType(elementType));
+ });
+
+ const auto materializeCast = [](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) -> Value {
+ if (inputs.size() != 1)
+ return Value();
+ return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ typeConverter.addTypeAttributeConversion(
+ [](IntegerType type, IntegerAttr attribute) -> Attribute {
+ const APInt value = attribute.getValue().truncSSat(32);
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
+ value);
+ });
+ typeConverter.addTypeAttributeConversion(
+ [&typeConverter](ShapedType type,
+ DenseIntElementsAttr attr) -> Attribute {
+ const ShapedType newType =
+ cast<ShapedType>(typeConverter.convertType(type));
+ const auto oldElementType = cast<IntegerType>(type.getElementType());
+ const auto newElementType =
+ cast<IntegerType>(newType.getElementType());
+ if (oldElementType.getWidth() == newElementType.getWidth())
+ return attr;
+
+ DenseElementsAttr mapped =
+ attr.mapValues(newElementType, [&](const APInt &v) {
+ return v.truncSSat(newElementType.getWidth());
+ });
+ return mapped;
+ });
+
+ ConversionTarget target(*context);
+ target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+ [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op->getResultTypes()) &&
+ typeConverter.isLegal(op->getOperandTypes());
+ });
+ if (convertFunctionBoundaries) {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&typeConverter](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
+ const FunctionType funcType =
+ op->getParentOfType<func::FuncOp>().getFunctionType();
+ return llvm::equal(op.getOperandTypes(), funcType.getResults());
+ });
+ } else {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return true; });
+ target.addDynamicallyLegalOp<func::ReturnOp>(
+ [](func::ReturnOp op) { return true; });
+ }
+
+ RewritePatternSet patterns(context);
+ if (convertFunctionBoundaries) {
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ }
+ if (aggressiveRewrite) {
+ patterns.add<ConvertGenericOp>(typeConverter, context);
+ } else {
+ // Tensor
+ patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+ // Data layout
+ patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
+ // Type conversion
+ patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
+ // Controlflow
+ patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
+ }
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
new file mode 100644
index 0000000000000..1a36177a37033
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// CHECK-LABEL: test_i64_argmax_large_axis_dim
+func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
+ // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_convert_input_parameters
+// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64>
+// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32>
+func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> {
+ // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+ // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+
+ // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ return %1 : tensor<1x513x513x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64>
+// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32>
+func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32>
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
+ // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64>
+func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32>
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) {
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[ADD]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ } else {
+ // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[SUB]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ }
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64>
+ // DEFAULT: return %[[OUT]] : tensor<i64>
+ // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32>
+ return %0 : tensor<i64>
+}
+
+// -----
+
+// CHECK-LABEL: test_const
+func.func @test_const() -> tensor<2xi64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
+ // DEFAULT: return %[[OUT]] : tensor<2xi64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
+ return %0 : tensor<2xi64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
new file mode 100644
index 0000000000000..a14483fcdd7b0
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-p...
[truncated]
|
This pass aims to narrow i64 types on TOSA operations to i32. It can be useful for legalizations from various frameworks. It comes with the following options:
Currently the non aggressive mode is very limited, targeting an argmax -> cast sequence that has been observed during legalization as well as some data layout operations that can always narrow. Support for more operations will be added in the future.
Co-authored-by: Vitalii Shutov vitalii.shutov@arm.com
Co-authored-by: Shubham shubham@arm.com
Co-authored-by: Declan Flavin declan.flavin@arm.com