@@ -26,20 +26,20 @@ def run(f):
2626# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
2727# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
2828# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
29- # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
30- # CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
31- # CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
32- # CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
33- # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
29+ # CHECK: %[[S0:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = 5, max_val = 10} : !torch.int
30+ # CHECK: %[[S1:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
31+ # CHECK: %[[S2:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
32+ # CHECK: %[[S3:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
33+ # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32>
3434# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
35- # CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
35+ # CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S3]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32>
3636# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
37- # CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
37+ # CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32>
3838# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
3939# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
4040# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
4141# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
42- # CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
42+ # CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S1]], %[[S3]], %[[S0]], %[[S2]]], affine_map<()[s0, s1, s2, s3] -> (s2, s1 + s3 + s0 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
4343# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32>
4444def test_tanh_sigmoid_cat ():
4545 class TanhSigmoidCat (nn .Module ):
0 commit comments