-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][XeGPU] add xegpu.set_desc_layout transform op #165615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tkarna
wants to merge
9
commits into
llvm:main
Choose a base branch
from
tkarna:xegpu-tr-ops-set-desc-layout
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+584
−7
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
409ac2f
[mlir][xegpu] add xegpu.set_desc_layout transform op
tkarna 52a3058
address Adam's comments
tkarna 1be8cef
nit comments
tkarna a89af77
more nit comments
tkarna 8543b91
move TransformAnyParamTypeOrAnyHandle to transform dialect
tkarna 6992b14
xegpu: setDescLayout retains TensorDesc BlockTensorDescAttrs
tkarna 05250fb
move extension registration to the end + minor updates
tkarna 3a2fa81
sg_data is now required arg
tkarna 7ea2528
py bindings: target_value -> target_handle
tkarna File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,3 @@ | ||
| add_subdirectory(IR) | ||
| add_subdirectory(Transforms) | ||
| add_subdirectory(TransformOps) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td) | ||
| mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls) | ||
| mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs) | ||
| add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen) | ||
|
|
||
| add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc) |
28 changes: 28 additions & 0 deletions
28
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| //===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H | ||
| #define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H | ||
|
|
||
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
| #include "mlir/Dialect/Transform/IR/TransformTypes.h" | ||
| #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" | ||
| #include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
|
|
||
| #define GET_OP_CLASSES | ||
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc" | ||
|
|
||
| namespace mlir { | ||
| class DialectRegistry; | ||
|
|
||
| namespace xegpu { | ||
| void registerTransformDialectExtension(DialectRegistry ®istry); | ||
| } // namespace xegpu | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H |
81 changes: 81 additions & 0 deletions
81
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| //===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef XEGPU_TRANSFORM_OPS | ||
| #define XEGPU_TRANSFORM_OPS | ||
|
|
||
| include "mlir/Dialect/Transform/IR/TransformAttrs.td" | ||
| include "mlir/Dialect/Transform/IR/TransformDialect.td" | ||
| include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" | ||
| include "mlir/Dialect/Transform/IR/TransformTypes.td" | ||
| include "mlir/Interfaces/SideEffectInterfaces.td" | ||
| include "mlir/IR/OpBase.td" | ||
|
|
||
| def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ | ||
| AttrSizedOperandSegments, | ||
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, | ||
| TransformOpInterface | ||
| ]> { | ||
|
|
||
| let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result."; | ||
| let description = [{ | ||
| Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout` | ||
| attribute to the result tensor descriptor. The layout is defined by the | ||
| `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle | ||
| to the transformed op. | ||
| }]; | ||
|
|
||
| let arguments = (ins | ||
| TransformHandleTypeInterface : $target, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data, | ||
adam-smnk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data | ||
| ); | ||
|
|
||
| let results = (outs TransformHandleTypeInterface : $transformed); | ||
| let builders = [ | ||
| OpBuilder<(ins "Value":$target, | ||
| "ArrayRef<OpFoldResult>":$mixedSgLayout, | ||
| "ArrayRef<OpFoldResult>":$mixedSgData, | ||
| "ArrayRef<OpFoldResult>":$mixedInstData | ||
| )>, | ||
| ]; | ||
|
|
||
| let assemblyFormat = [{ | ||
| $target | ||
| `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout) | ||
| `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data) | ||
| (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)? | ||
| attr-dict `:` functional-type(operands, results) | ||
| }]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| ::mlir::DiagnosedSilenceableFailure apply( | ||
| ::mlir::transform::TransformRewriter &rewriter, | ||
| ::mlir::transform::TransformResults &transformResults, | ||
| ::mlir::transform::TransformState &state); | ||
|
|
||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticSgLayout(), getSgLayout(), b); | ||
| } | ||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticSgData(), getSgData(), b); | ||
| } | ||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticInstData(), getInstData(), b); | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
| #endif // XEGPU_TRANSFORM_OPS | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| add_subdirectory(IR) | ||
| add_subdirectory(Transforms) | ||
| add_subdirectory(Utils) | ||
| add_subdirectory(TransformOps) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| add_mlir_dialect_library(MLIRXeGPUTransformOps | ||
| XeGPUTransformOps.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ | ||
|
|
||
| DEPENDS | ||
| MLIRXeGPUTransformOpsIncGen | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIRXeGPUDialect | ||
| MLIRXeGPUTransforms | ||
| MLIRIR | ||
| MLIRTransformDialect | ||
| MLIRFuncDialect | ||
| MLIRSCFDialect | ||
| ) |
225 changes: 225 additions & 0 deletions
225
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,225 @@ | ||
| //===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" | ||
| #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" | ||
|
|
||
| #include <optional> | ||
|
|
||
| using namespace mlir; | ||
| using namespace mlir::transform; | ||
|
|
||
| /// Assuming that `ofr` is an index attr or a param of index type | ||
| /// or a transform dialect handle mapped to exactly one op | ||
| /// with one index result, get that value and cast it to int type. | ||
| static DiagnosedSilenceableFailure convertMixedValuesToInt( | ||
| transform::TransformState &state, TransformOpInterface transformOp, | ||
| SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) { | ||
| for (OpFoldResult ofr : ofrs) { | ||
| // Attribute case. | ||
| if (auto attr = dyn_cast<Attribute>(ofr)) { | ||
| if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { | ||
adam-smnk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| result.push_back(intAttr.getInt()); | ||
| continue; | ||
| } | ||
| return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; | ||
| } | ||
|
|
||
| // Transform param case. | ||
| Value transformValue = cast<Value>(ofr); | ||
| if (isa<TransformParamTypeInterface>(transformValue.getType())) { | ||
| ArrayRef<Attribute> params = state.getParams(transformValue); | ||
| if (params.size() != 1) | ||
| return transformOp.emitDefiniteFailure() | ||
| << "requires exactly one parameter associated"; | ||
| result.push_back( | ||
| cast<IntegerAttr>(params.front()).getValue().getSExtValue()); | ||
| continue; | ||
| } | ||
|
|
||
| // Payload value case. | ||
| auto payloadOps = state.getPayloadOps(transformValue); | ||
| if (!llvm::hasSingleElement(payloadOps)) { | ||
| DiagnosedSilenceableFailure diag = | ||
| transformOp.emitSilenceableError() | ||
| << "handle must be mapped to exactly one payload op"; | ||
| diag.attachNote(transformValue.getLoc()) | ||
| << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; | ||
| return diag; | ||
| } | ||
|
|
||
| Operation *op = *payloadOps.begin(); | ||
| if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { | ||
| DiagnosedSilenceableFailure diag = | ||
| transformOp.emitSilenceableError() | ||
| << "payload op must have exactly 1 index result"; | ||
| diag.attachNote(op->getLoc()) | ||
| << "has " << op->getNumResults() << " results"; | ||
| return diag; | ||
| } | ||
|
|
||
| IntegerAttr intAttr; | ||
| if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) | ||
| return transformOp.emitSilenceableError() | ||
| << "requires param or handle to be the result of a constant like " | ||
| "op"; | ||
|
|
||
| result.push_back(intAttr.getInt()); | ||
| } | ||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| /// Create a layout attribute from the given parameters. | ||
| static xegpu::LayoutAttr | ||
| createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, | ||
| ArrayRef<int32_t> sgData, | ||
| std::optional<ArrayRef<int32_t>> instData) { | ||
| return xegpu::LayoutAttr::get( | ||
| ctx, DenseI32ArrayAttr::get(ctx, sgLayout), | ||
| DenseI32ArrayAttr::get(ctx, sgData), | ||
| instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, | ||
| /*lane_layout=*/nullptr, | ||
| /*lane_data=*/nullptr, | ||
| /*order=*/nullptr); | ||
| } | ||
|
|
||
| /// Replace xegpu.create_nd_desc op with a new one with the given layout. | ||
| static xegpu::CreateNdDescOp | ||
| setDescLayout(transform::TransformRewriter &rewriter, | ||
| xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { | ||
| auto oldTensorDesc = descOp.getType(); | ||
| auto descType = xegpu::TensorDescType::get( | ||
| oldTensorDesc.getShape(), oldTensorDesc.getElementType(), | ||
| /*array_length=*/oldTensorDesc.getArrayLength(), | ||
| /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), | ||
| /*memory_space=*/oldTensorDesc.getMemorySpace(), | ||
| /*layout=*/layout); | ||
|
|
||
| rewriter.setInsertionPointAfter(descOp); | ||
| assert(descOp.getMixedOffsets().size() == 0 && | ||
| "create desc op with offsets is not supported"); | ||
| auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( | ||
| descOp, descType, descOp.getSource(), descOp.getMixedSizes(), | ||
| descOp.getMixedStrides()); | ||
| return newDescOp; | ||
| } | ||
|
|
||
| void transform::SetDescLayoutOp::build(OpBuilder &builder, | ||
| OperationState &result, Value target, | ||
| ArrayRef<OpFoldResult> mixedSgLayout, | ||
| ArrayRef<OpFoldResult> mixedSgData, | ||
| ArrayRef<OpFoldResult> mixedInstData) { | ||
| SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; | ||
| SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; | ||
| dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); | ||
| dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); | ||
| dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); | ||
| build(builder, result, target.getType(), | ||
| /*target=*/target, | ||
| /*sg_layout=*/dynamicSgLayout, | ||
| /*sg_data=*/dynamicSgData, | ||
| /*inst_data=*/dynamicInstData, | ||
| /*static_sg_layout=*/staticSgLayout, | ||
| /*static_sg_data=*/staticSgData, | ||
| /*static_inst_data=*/staticInstData); | ||
| } | ||
|
|
||
| DiagnosedSilenceableFailure | ||
| transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, | ||
| transform::TransformResults &results, | ||
| transform::TransformState &state) { | ||
| auto targetOps = state.getPayloadOps(getTarget()); | ||
| if (!llvm::hasSingleElement(targetOps)) { | ||
| return emitDefiniteFailure() << "requires exactly one targetOp handle (got " | ||
| << llvm::range_size(targetOps) << ")"; | ||
| } | ||
| Operation *target = *targetOps.begin(); | ||
|
|
||
| SmallVector<int32_t> sgLayout; | ||
| DiagnosedSilenceableFailure status = | ||
| convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> sgData; | ||
| status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> instData; | ||
| status = | ||
| convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
| auto maybeInstData = instData.empty() | ||
| ? std::nullopt | ||
| : std::optional<ArrayRef<int32_t>>(instData); | ||
|
|
||
| // For now only create_nd_desc op is supported. | ||
adam-smnk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); | ||
| if (!descOp) { | ||
| auto diag = emitSilenceableFailure(getLoc()) | ||
| << "Expected a xegpu.create_nd_desc op, but got: " | ||
| << target->getName(); | ||
| diag.attachNote(target->getLoc()) << "target op"; | ||
| return diag; | ||
| } | ||
|
|
||
| // Set layout attr in desc op's return type. Replaces old desc op. | ||
| auto layoutAttr = | ||
| createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); | ||
| auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); | ||
|
|
||
| // Map result handles. | ||
| results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); | ||
|
|
||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| void transform::SetDescLayoutOp::getEffects( | ||
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
| consumesHandle(getTargetMutable(), effects); | ||
| onlyReadsHandle(getSgLayoutMutable(), effects); | ||
| onlyReadsHandle(getSgDataMutable(), effects); | ||
| onlyReadsHandle(getInstDataMutable(), effects); | ||
| producesHandle(getOperation()->getOpResults(), effects); | ||
| modifiesPayload(effects); | ||
| } | ||
|
|
||
| namespace { | ||
| class XeGPUTransformDialectExtension | ||
| : public transform::TransformDialectExtension< | ||
| XeGPUTransformDialectExtension> { | ||
| public: | ||
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) | ||
|
|
||
| using Base::Base; | ||
|
|
||
| void init(); | ||
| }; | ||
|
|
||
| void XeGPUTransformDialectExtension::init() { | ||
| declareGeneratedDialect<scf::SCFDialect>(); | ||
| declareGeneratedDialect<arith::ArithDialect>(); | ||
| declareGeneratedDialect<xegpu::XeGPUDialect>(); | ||
|
|
||
| registerTransformOps< | ||
| #define GET_OP_LIST | ||
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" | ||
| >(); | ||
| } | ||
| } // namespace | ||
|
|
||
| #define GET_OP_CLASSES | ||
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" | ||
|
|
||
| void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { | ||
| registry.addExtensions<XeGPUTransformDialectExtension>(); | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.