Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,11 +1002,17 @@ def relu(x):

@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
if torch.version.hip:
return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.minimum({a}, {b})"

@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
if torch.version.hip:
return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.maximum({a}, {b})"

@staticmethod
def where(a, b, c):
Expand Down Expand Up @@ -1202,7 +1208,10 @@ def load_seed(name, offset):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
    """Emit the Triton source expression for ``1/sqrt(x)``.

    Args:
        x: string holding an already-codegen'd Triton sub-expression.

    Returns:
        A Triton expression string. ROCm builds (``torch.version.hip``
        truthy) use the native ``tl.rsqrt``; CUDA builds go through
        ``libdevice.rsqrt``.
    """
    if torch.version.hip:
        return f"tl.rsqrt({x})"
    else:
        return f"libdevice.rsqrt({x})"

@staticmethod
@maybe_upcast_float32()
Expand Down Expand Up @@ -3227,8 +3236,9 @@ def codegen_body(self):
loop_end = (
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
num_stages = ", num_stages = 2" if torch.version.hip else ""
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ class triton:
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
# Raise the threshold (16 on CUDA, 32 on ROCm) to be safe.
# We should revisit this once we understand more of the source of register spills.
spill_threshold: int = 32 if torch.version.hip else 16

# Generate code containing the newer tl.make_block_ptr() API for loads/store
use_block_ptr = False
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/runtime/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from enum import auto, Enum
from typing import Optional, Union

import torch
from torch.utils._triton import has_triton_package


# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
# NOTE: if these fail asserts submit a PR to increase them
TRITON_MAX_BLOCK = {
"X": 4096,
"X": 8192 if torch.version.hip else 4096,
"Y": 1024,
"Z": 1024,
"R0_": 4096 * 16, # * 16 is multi-kernel only
Expand Down
Loading