Skip to content

Commit a6a3069

Browse files
naromero77amdjeffdaily
authored andcommitted
[release/2.9][ROCm][inductor] Improved fast_tanh code generation (#2804)
In the ROCm fork of PyTorch 2.9, Inductor currently has codegen support for fast_tanhf. However, there were some NaN issues in the original Triton implementation of fast_tanhf . Upstream Triton has an improved fast_tanhf where the NaN issues are now fixed. This upstream commit has been backported to ROCm fork of Triton (see code comments). A bump in the Triton commit is also needed. Other notes: - In support of [SWDEV-560271](https://ontrack-internal.amd.com/browse/SWDEV-560271) - Triton 3.5 backport of upstream Triton commit ROCm/triton#901 - Similar to #2802, #2803 - Related to pytorch#162052
1 parent a2b0fd7 commit a6a3069

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d704bc6e69c1a588c8edd3cbb67505d554ed65f6
1+
5df9c723de8c23508773b07fe16dd34e4c444541

torch/_inductor/codegen/triton.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torch._prims_common import is_integer_dtype
2727
from torch.utils._ordered_set import OrderedSet
2828
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
29-
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api
29+
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api, get_triton_version
3030

3131
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
3232
from ...utils._sympy.value_ranges import ValueRanges
@@ -1315,7 +1315,12 @@ def tan(x):
13151315
@staticmethod
13161316
@maybe_upcast_float32()
13171317
def tanh(x):
1318-
return f"libdevice.fast_tanhf({x})"
1318+
if torch.version.hip and get_triton_version() > (3, 2):
1319+
# On ROCm, use fast_tanhf depending on Triton version
1320+
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
1321+
return f"libdevice.fast_tanhf({x})"
1322+
else:
1323+
return f"libdevice.tanh({x})"
13191324

13201325
@staticmethod
13211326
@maybe_upcast_float32()

0 commit comments

Comments
 (0)