Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"

// This is roughly similar to OpFoldResult assuming the handle produces a single
// value in the payload IR.
def TransformAnyParamTypeOrAnyHandle : Type<
Or<[TransformHandleTypeInterface.predicate,
TransformParamTypeInterface.predicate]>,
"transform any param type or any handle type">;

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,9 @@ def TransformAnyHandle : Type<
TransformValueHandleTypeInterface.predicate]>,
"transform operation or value handle">;

// Constraint accepting either any transform param or any transform handle.
// This is roughly similar to OpFoldResult assuming the handle produces a
// single value in the payload IR.
def TransformAnyParamTypeOrAnyHandle : Type<
    Or<[TransformHandleTypeInterface.predicate,
        TransformParamTypeInterface.predicate]>,
    "transform any param type or any handle type">;

#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Generate op declarations/definitions for the XeGPU transform ops from the
# tablegen description in XeGPUTransformOps.td.
set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)

# Generate the op documentation page.
add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace xegpu {
/// Registers the XeGPU transform dialect extension with `registry`, making
/// the `transform.xegpu.*` ops available to the transform dialect.
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace xegpu
} // namespace mlir

#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef XEGPU_TRANSFORM_OPS
#define XEGPU_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  TransformOpInterface
]> {

  let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
  let description = [{
    Given an `xegpu.create_nd_desc` operation, this transform adds an
    `xegpu.layout` attribute to the result tensor descriptor. The layout is
    defined by the `sg_layout` and `sg_data` attributes and an optional
    `inst_data` attribute. Returns a handle to the transformed op.
  }];

  // Each dimension list is "mixed": a value may be supplied statically via
  // the corresponding static_* attribute or dynamically as a transform
  // param/handle operand, mirroring the OpFoldResult idiom.
  let arguments = (ins
    TransformHandleTypeInterface : $target,
    Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
    Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
    Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
  );

  // Handle to the replacement create_nd_desc op.
  let results = (outs TransformHandleTypeInterface : $transformed);
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedSgData,
                   "ArrayRef<OpFoldResult>":$mixedInstData
    )>,
  ];

  let assemblyFormat = [{
    $target
    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
    attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Accessors merging the static attribute and dynamic operand forms of
    // each dimension list into a single OpFoldResult list.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
      Builder b(getContext());
      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
      Builder b(getContext());
      return getMixedValues(getStaticSgData(), getSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
      Builder b(getContext());
      return getMixedValues(getStaticInstData(), getInstData(), b);
    }
  }];
}

#endif // XEGPU_TRANSFORM_OPS
1 change: 1 addition & 0 deletions mlir/lib/Dialect/XeGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
add_subdirectory(TransformOps)
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Library implementing the XeGPU transform dialect extension ops.
add_mlir_dialect_library(MLIRXeGPUTransformOps
  XeGPUTransformOps.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/

  DEPENDS
  # Tablegen-generated op declarations/definitions must exist first.
  MLIRXeGPUTransformOpsIncGen

  LINK_LIBS PUBLIC
  MLIRXeGPUDialect
  MLIRXeGPUTransforms
  MLIRIR
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)
225 changes: 225 additions & 0 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::transform;

/// Assuming that `ofr` is an index attr or a param of index type
/// or a transform dialect handle mapped to exactly one op
/// with one index result, get that value and cast it to int type.
/// Appends one int32 per entry of `ofrs` to `result`. Emits a definite
/// failure for malformed attributes or params (the transform itself is
/// ill-formed) and a silenceable failure when a handle does not resolve to a
/// single constant-like index op in the payload.
static DiagnosedSilenceableFailure convertMixedValuesToInt(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult ofr : ofrs) {
    // Attribute case: a static value encoded directly as an IntegerAttr.
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
        result.push_back(intAttr.getInt());
        continue;
      }
      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
    }

    // Transform param case: the param must carry exactly one IntegerAttr.
    Value transformValue = cast<Value>(ofr);
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      ArrayRef<Attribute> params = state.getParams(transformValue);
      if (params.size() != 1)
        return transformOp.emitDefiniteFailure()
               << "requires exactly one parameter associated";
      result.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    // Payload value case: a handle mapped to one payload op whose single
    // index result folds to a constant.
    auto payloadOps = state.getPayloadOps(transformValue);
    if (!llvm::hasSingleElement(payloadOps)) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(transformValue.getLoc())
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
      return diag;
    }

    Operation *op = *payloadOps.begin();
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }

    // The result must be producible as a compile-time constant.
    IntegerAttr intAttr;
    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
      return transformOp.emitSilenceableError()
             << "requires param or handle to be the result of a constant like "
                "op";

    result.push_back(intAttr.getInt());
  }
  return DiagnosedSilenceableFailure::success();
}

/// Build an `xegpu::LayoutAttr` from the subgroup layout/data shapes and an
/// optional instruction-data shape. Lane layout, lane data, and order are
/// deliberately left unset.
static xegpu::LayoutAttr
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
                 ArrayRef<int32_t> sgData,
                 std::optional<ArrayRef<int32_t>> instData) {
  auto sgLayoutAttr = DenseI32ArrayAttr::get(ctx, sgLayout);
  auto sgDataAttr = DenseI32ArrayAttr::get(ctx, sgData);
  DenseI32ArrayAttr instDataAttr =
      instData ? DenseI32ArrayAttr::get(ctx, *instData) : nullptr;
  return xegpu::LayoutAttr::get(ctx, sgLayoutAttr, sgDataAttr, instDataAttr,
                                /*lane_layout=*/nullptr,
                                /*lane_data=*/nullptr,
                                /*order=*/nullptr);
}

/// Replace `descOp` (an `xegpu.create_nd_desc`) with an identical op whose
/// result tensor-descriptor type additionally carries `layout`. The original
/// op must not specify offsets. Returns the replacement op.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
  // Check the precondition up front, before mutating any rewriter state.
  assert(descOp.getMixedOffsets().empty() &&
         "create desc op with offsets is not supported");

  // Rebuild the descriptor type with identical shape/element type/flags and
  // the requested layout attached.
  auto oldTensorDesc = descOp.getType();
  auto descType = xegpu::TensorDescType::get(
      oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
      /*array_length=*/oldTensorDesc.getArrayLength(),
      /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
      /*memory_space=*/oldTensorDesc.getMemorySpace(),
      /*layout=*/layout);

  rewriter.setInsertionPointAfter(descOp);
  return rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
      descOp.getMixedStrides());
}

/// Builder taking mixed (static + dynamic) dimension lists; splits each list
/// into its variadic dynamic operands and its static integer-array attribute.
void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                       OperationState &result, Value target,
                                       ArrayRef<OpFoldResult> mixedSgLayout,
                                       ArrayRef<OpFoldResult> mixedSgData,
                                       ArrayRef<OpFoldResult> mixedInstData) {
  SmallVector<Value> sgLayoutValues, sgDataValues, instDataValues;
  SmallVector<int64_t> sgLayoutStatic, sgDataStatic, instDataStatic;
  dispatchIndexOpFoldResults(mixedSgLayout, sgLayoutValues, sgLayoutStatic);
  dispatchIndexOpFoldResults(mixedSgData, sgDataValues, sgDataStatic);
  dispatchIndexOpFoldResults(mixedInstData, instDataValues, instDataStatic);
  build(builder, result, target.getType(),
        /*target=*/target,
        /*sg_layout=*/sgLayoutValues,
        /*sg_data=*/sgDataValues,
        /*inst_data=*/instDataValues,
        /*static_sg_layout=*/sgLayoutStatic,
        /*static_sg_data=*/sgDataStatic,
        /*static_inst_data=*/instDataStatic);
}

/// Applies the transform: resolves the mixed dimension lists, builds the
/// layout attribute, and replaces the targeted create_nd_desc op with one
/// whose result type carries that layout.
DiagnosedSilenceableFailure
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  // The target handle must map to exactly one payload op.
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  // Resolve each mixed static/dynamic dimension list to plain integers.
  SmallVector<int32_t> sgLayout;
  DiagnosedSilenceableFailure status =
      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
  if (!status.succeeded())
    return status;

  SmallVector<int32_t> sgData;
  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
  if (!status.succeeded())
    return status;

  SmallVector<int32_t> instData;
  status =
      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
  if (!status.succeeded())
    return status;
  // inst_data is optional; an empty list means "not specified".
  auto maybeInstData = instData.empty()
                           ? std::nullopt
                           : std::optional<ArrayRef<int32_t>>(instData);

  // For now only create_nd_desc op is supported.
  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
  if (!descOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a xegpu.create_nd_desc op, but got: "
                << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Set layout attr in desc op's return type. Replaces old desc op.
  auto layoutAttr =
      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);

  // Map result handles.
  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});

  return DiagnosedSilenceableFailure::success();
}

/// Declares memory effects: the target op is replaced, so its handle is
/// consumed; the dimension operands are only inspected.
void transform::SetDescLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

namespace {
/// Transform dialect extension that registers the XeGPU transform ops and
/// declares the dialects their application may generate.
class XeGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          XeGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)

  using Base::Base;

  void init();
};

void XeGPUTransformDialectExtension::init() {
  // Dialects whose ops may be created in the payload IR by this extension's
  // transforms.
  declareGeneratedDialect<scf::SCFDialect>();
  declareGeneratedDialect<arith::ArithDialect>();
  declareGeneratedDialect<xegpu::XeGPUDialect>();

  registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
      >();
}
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"

/// Hooks the XeGPU transform dialect extension into `registry`, making the
/// `transform.xegpu.*` ops available.
void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<XeGPUTransformDialectExtension>();
}
Loading