[MLIR][XeGPU] add xegpu.set_desc_layout transform op #165615
Conversation
@llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Adds the first XeGPU transform op, xegpu.set_desc_layout. Given a handle to an xegpu.create_nd_tdesc op, this transform op adds a xegpu.layout attribute to the returned descriptor:

module {
func.func @run(%arg0: memref<4096x4096xf32>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> ->
!xegpu.tensor_desc<256x256xf32>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

Applying the transform op produces:

func.func @run(%arg0: memref<4096x4096xf32>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> ->
!xegpu.tensor_desc<256x256xf32,
#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
return
}

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.

Patch is 25.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165615.diff

13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..5924606402a02
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
+mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
+
+add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
new file mode 100644
index 0000000000000..dab0c3f35adda
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -0,0 +1,28 @@
+//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+
+namespace mlir {
+class DialectRegistry;
+
+namespace xegpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..681b4861f0aeb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -0,0 +1,85 @@
+//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef XEGPU_EXTENSION
+#define XEGPU_EXTENSION
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def TransformAnyParamTypeOrAnyHandle : Type<
+ Or<[TransformHandleTypeInterface.predicate,
+ TransformParamTypeInterface.predicate]>,
+ "transform any param type or any handle type">;
+
+def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Set xegpu.layout attribute to an xegpu op result.";
+ let description = [{
+ Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
+ attribute to the result tensor descriptor. The layout is defined by the
+ `sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
+ }];
+
+ let arguments = (ins
+ TransformHandleTypeInterface : $target,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+ );
+
+ let results = (outs TransformHandleTypeInterface : $transformed);
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
+ "ArrayRef<OpFoldResult>":$mixedSgData,
+ "ArrayRef<OpFoldResult>":$mixedInstData
+ )>,
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+ `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+ attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgData(), getSgData(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticInstData(), getInstData(), b);
+ }
+ }];
+}
+
+#endif // XEGPU_EXTENSION
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..46b8251a57797 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..48fe841afaa83
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+ XeGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+ DEPENDS
+ MLIRXeGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRXeGPUDialect
+ MLIRXeGPUTransforms
+ MLIRIR
+ MLIRTransformDialect
+ MLIRFuncDialect
+ MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000000000..1875f1050eb03
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,250 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <numeric>
+
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+class XeGPUTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ XeGPUTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+ using Base::Base;
+
+ void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+ declareGeneratedDialect<scf::SCFDialect>();
+ declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<gpu::GPUDialect>();
+ declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+ registerTransformOps<
+#define GET_OP_LIST
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+ >();
+}
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+ registry.addExtensions<XeGPUTransformDialectExtension>();
+}
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ // Attribute case.
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ result.push_back(intAttr.getInt());
+ } else {
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ }
+ continue;
+ }
+
+ // Transform param case.
+ Value transformValue = cast<Value>(ofr);
+ if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ result.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ // Payload value case.
+ auto payloadOps = state.getPayloadOps(transformValue);
+ if (!llvm::hasSingleElement(payloadOps)) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(transformValue.getLoc())
+ << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+ return diag;
+ }
+
+ Operation *op = *payloadOps.begin();
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+
+ IntegerAttr intAttr;
+ if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+ return transformOp.emitSilenceableError()
+ << "requires param or handle to be the result of a constant like "
+ "op";
+
+ result.push_back(intAttr.getInt());
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+ ArrayRef<int32_t> sgData,
+ std::optional<ArrayRef<int32_t>> instData) {
+ return xegpu::LayoutAttr::get(
+ ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+ DenseI32ArrayAttr::get(ctx, sgData),
+ instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+ /*lane_layout=*/nullptr,
+ /*lane_data=*/nullptr,
+ /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
+ xegpu::CreateNdDescOp descOp,
+ xegpu::LayoutAttr layout) {
+ auto oldTensorDesc = descOp.getResult();
+ auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
+ auto descType = xegpu::TensorDescType::get(
+ descShapedType.getShape(), descShapedType.getElementType(),
+ /*array_length=*/1,
+ /*boundary_check=*/true,
+ /*memory_space=*/xegpu::MemorySpace::Global,
+ /*layout=*/layout);
+
+ rewriter.setInsertionPointAfter(descOp);
+ if (descOp.getMixedOffsets().size() > 0) {
+ auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descType, descOp.getSource(), descOp.getMixedOffsets(),
+ descOp.getMixedSizes(), descOp.getMixedStrides());
+ return newDescOp;
+ }
+ auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+ descOp.getMixedStrides());
+ return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+ OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedSgLayout,
+ ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, result, target.getType(),
+ /*target=*/target,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int32_t> sgLayout;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
+ if (!status.succeeded())
+ return status;
+
+ SmallVector<int32_t> sgData;
+ status =
+ convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
+ if (!status.succeeded())
+ return status;
+
+ SmallVector<int32_t> instData;
+ status =
+ convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
+ if (!status.succeeded())
+ return status;
+
+ // For now only create_nd_desc op is supported.
+ auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+ if (!descOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a xegpu.create_nd_desc op, but got: "
+ << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Set layout attr in desc op's return type. Replaces old desc op.
+ auto layoutAttr =
+ createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+ auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+ // Map result handles.
+ results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..c857c38df717c 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
+ xegpu::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 20ed3ab41a0b4..51c75764faf3c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings(
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
)
+declare_mlir_dialect_extension_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/XeGPUTransformOps.td
+ SOURCES
+ dialects/transform/xegpu.py
+ DIALECT_NAME transform
+ EXTENSION_NAME xegpu_transform)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..5a5e7b912c4a5
--- /dev/null
+++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td
@@ -0,0 +1,19 @@
+//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the XeGPU transform ops.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
+
+#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
new file mode 100644
index 0000000000000..720ee4070fbec
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/xe...
[truncated]
✅ With the latest revision this PR passed the Python code formatter.
Force-pushed from 44a5e63 to 1bbe829.
also cc @dchigarev |
Looking at the other proposed transforms, `set_op_layout_attr` appears more generic. This transform can only be applied to XeGPU ops that return a tensor descriptor.
Overall looks good, it's a great addition to Xe tools 👍
Minor comments here and there
Re: mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td, `let arguments = (ins ...`:
Why not cover whole layout with lane and order too?
We could include the rest, lane_layout, lane_data, and order. These are however not needed in the current lowering workflow (lane layout/data only shows up later on in the pipeline, after xegpu-propagate-layout pass). I'd prefer to add features as they are needed (including those 3 variadic operands to all the transform ops would also imply a lot of boilerplate code).
Sounds valid, let's leave it for later.
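(For context on the exchange above: a descriptor carrying the full layout, including the lane_layout, lane_data and order fields this op deliberately omits for now, might look roughly like the sketch below. The field names come from the createLayoutAttr helper in this patch; the concrete values are illustrative only.)

```mlir
// Illustrative only: a tensor_desc annotated with lane_layout, lane_data and
// order in addition to the fields the transform op currently sets.
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> ->
       !xegpu.tensor_desc<256x256xf32,
         #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16],
                       lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>>
```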
Overall design and functionality look good to me 👍
Still, I'm not an expert on transforms or python bindings.
Let's wait for somebody more knowledgeable in these areas to double check as well.
Adds the first XeGPU transform op, xegpu.set_desc_layout.

Given a handle to an xegpu.create_nd_tdesc op, this transform op adds a xegpu.layout attribute to the returned descriptor; applying the transform produces the annotated descriptor type shown in the summary above. sg_layout is a required argument, sg_data and inst_data are optional.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.
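A note on the mixed static/dynamic sizes: because the op parses sg_layout, sg_data and inst_data as DynamicIndexList operands, the values can presumably also be supplied as transform params instead of integer literals. The sketch below is an assumption based on that interface, not an example taken from this patch, and relies on the standard transform.param.constant op:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    // sg_data passed dynamically as a transform param; sg_layout and
    // inst_data remain static.
    %c32 = transform.param.constant 32 : i64 -> !transform.param<i64>
    %1 = transform.xegpu.set_desc_layout %0
        sg_layout = [8, 8] sg_data = [%c32, %c32] inst_data = [8, 16]
        : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
    transform.yield
  }
}
```

In the apply() implementation above, convertMixedValuesToInt resolves such params (and constant-like payload handles) back to integers before the layout attribute is built.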