diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index a4888a218fae..cf80829bf1a5 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4147,6 +4147,10 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, if (!operandTy || !operandTy.hasSizes()) return failure(); int64_t adim = dim < 0 ? dim + operandTy.getSizes().size() : dim; + bool is1DEmptyTensor = + operandTy.getSizes().size() == 1 && operandTy.getSizes()[0] == 0; + if (is1DEmptyTensor) + continue; if (operandTy.getSizes()[adim] != 0) filtered.push_back(operand); } diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 48092b71a875..57bdc3db6555 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3242,6 +3242,16 @@ func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vte // ----- +// CHECK-LABEL: @aten_cat_empty +func.func @aten_cat_empty(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[0],f32>) -> !torch.vtensor<[4,5,6],f32> { + // CHECK: return %arg0 : !torch.vtensor<[4,5,6],f32> + %list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[0],f32>) -> !torch.list + %dim = torch.constant.int -2 + %0 = torch.aten.cat %list, %dim : !torch.list, !torch.int -> !torch.vtensor<[4,5,6],f32> + return %0 : !torch.vtensor<[4,5,6],f32> +} +// ----- + // CHECK-LABEL: @aten_tensor_scalar_lt func.func @aten_tensor_scalar_lt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1>