From 1dc663199b92b6d2815681c62fbb42e5513bc5d1 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 31 Oct 2025 00:13:51 +0000 Subject: [PATCH 1/2] example: using nvrtc kernel for aot plugin --- examples/dynamo/nvrtc_aot_plugin.py | 247 ++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 examples/dynamo/nvrtc_aot_plugin.py diff --git a/examples/dynamo/nvrtc_aot_plugin.py b/examples/dynamo/nvrtc_aot_plugin.py new file mode 100644 index 0000000000..04b3ea7b72 --- /dev/null +++ b/examples/dynamo/nvrtc_aot_plugin.py @@ -0,0 +1,247 @@ +""" +Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue. + +This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op() +fails to compile operations that use fp16 (half-precision) tensors. + +The issue occurs because the JIT plugin generator doesn't properly declare format +support for fp16 data types in the generated TensorRT plugin. +""" + +from typing import List, Tuple, Union + +import torch + +# Import triton for kernel implementation +import triton +import triton.language as tl + +import torch_tensorrt + +# ============================================================================ +# Triton Kernel for Eager Execution +# ============================================================================ + + +@triton.jit +def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + # Program ID determines the block of data each thread will process + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Mask for boundary checking + mask = offsets < n_elements + # Load elements from the X tensor + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + # Convert to float32 for computation + x_f32 = x.to(tl.float32) + # Compute sigmoid: 1 / (1 + exp(-x)) + output = tl.sigmoid(x_f32) + # Convert back to original dtype + output_casted = output.to(x.dtype) + # Store the result in Y + tl.store(y_ptr + offsets, output_casted, mask=mask) + + +# ============================================================================ +# Custom Op Registration +# ============================================================================ + + +@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=()) # type: ignore[misc] +def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor: + # Ensure the tensor is on the GPU + assert X.is_cuda, "Tensor must be on CUDA device." 
+
+    # Create output tensor
+    Y = torch.empty_like(X)
+
+    # Define block size
+    BLOCK_SIZE = 256
+
+    # Grid of programs
+    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+
+    # Launch the kernel
+    pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE)
+
+    return Y
+
+
+@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
+def _(input: torch.Tensor) -> torch.Tensor:
+    """Fake implementation for TorchDynamo tracing of base operation."""
+    return torch.empty_like(input)
+
+
+# ============================================================================
+# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
+# ============================================================================
+
+import tensorrt.plugin as trtp
+from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions
+
+
+@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
+def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
+    return (input.like(),)
+
+
+@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
+def sigmoid_autotune(
+    input: trtp.TensorDesc,
+    outputs: Tuple[trtp.TensorDesc],
+) -> List[trtp.AutoTuneCombination]:
+    return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")]
+
+
+# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
+# def sigmoid_aot_triton_impl(
+#     input: trtp.TensorDesc,
+#     outputs: Tuple[trtp.TensorDesc],
+#     tactic: int,
+# ) -> Tuple[
+#     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+# ]:
+#     print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!")
+
+#     # Reuse the same Triton kernel we use for eager execution
+#     src = triton.compiler.ASTSource(
+#         fn=pointwise_sigmoid_kernel,
+#         signature={
+#             "x_ptr": "*fp16",
+#             "y_ptr": "*fp16",
+#             "n_elements": "i32",
+#             "BLOCK_SIZE": "constexpr",
+#         },
+#         constexprs={"BLOCK_SIZE": 256},
+#     )
+
+#     compiled_kernel = triton.compile(src)
+
+#     N = input.shape_expr.numel()
+#     launch_params = trtp.KernelLaunchParams()
+#     launch_params.grid_x = trtp.cdiv(N, 256)
+#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+#     launch_params.shared_mem = compiled_kernel.metadata.shared
+
+#     extra_args = trtp.SymIntExprs(1)
+#     extra_args[0] = trtp.SymInt32(N)
+
+#     print(compiled_kernel.asm["ptx"])
+
+#     return (
+#         compiled_kernel.metadata.name,
+#         compiled_kernel.asm["ptx"],
+#         launch_params,
+#         extra_args,
+#     )
+
+
+cu_code = """
+#include <cuda_fp16.h>
+
+// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
+__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
+                                               __half* __restrict__ output,
+                                               const int size) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (idx < size) {
+        const float x = __half2float(input[idx]);
+        const float result = 1.0f / (1.0f + expf(-x));
+        output[idx] = __float2half(result);
+    }
+}
+"""
+
+
+@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
+def sigmoid_aot_nvrtc_impl(
+    input: trtp.TensorDesc,
+    outputs: Tuple[trtp.TensorDesc],
+    tactic: int,
+) -> Tuple[
+    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
+    print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!")
+
+    dev = Device()
+    dev.set_current()
+    program_options = ProgramOptions(
+        std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"]
+    )
+    program = Program(cu_code, code_type="c++", options=program_options)
+    mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
+    compiled_kernel = 
mod.code.decode("utf-8") + print(compiled_kernel) + + N = input.shape_expr.numel() + launch_params = trtp.KernelLaunchParams() + launch_params.grid_x = trtp.cdiv((N + 256 - 1), 256) + launch_params.block_x = 256 + launch_params.shared_mem = 0 + + extra_args = trtp.SymIntExprs(1) + extra_args[0] = trtp.SymInt32(N) + + return ( + "pointwise_sigmoid_kernel_nvrtc", + compiled_kernel, + launch_params, + extra_args, + ) + + +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( + "pointwise_sigmoid_ops::pointwise_sigmoid", + supports_dynamic_shapes=True, + requires_output_allocator=False, +) + + +# ============================================================================ +# Test Model +# ============================================================================ + + +class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module): + """ + Test model that uses the TRT wrapper with custom_op() registration. + + When compiled with torch_tensorrt.compile() using fp16 inputs, this will + fail with: "could not find any supported formats consistent with input/output + data types" + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = torch.mul(input, 2) + y = torch.div(x, 2) + z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y) + a = torch.add(z, 1) + return a + + +if __name__ == "__main__": + model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval() + input = torch.randn(1, 1024, device="cuda", dtype=torch.float16) + + with torch_tensorrt.logging.debug(): + trt_inputs = [input] + model_trt = torch_tensorrt.compile( + model, + inputs=trt_inputs, + min_block_size=1, + ) + print("Model compiled successfully!") + print("Running inference with compiled model...") + with torch.no_grad(): + for i in range(10): + res = model_trt(input) + assert torch.allclose( + res, model(input), rtol=1e-2, atol=1e-2 + ), "Results do not match!" 
+
+    print("Inference successful!")

From 67abbd525e75e15d1832398f99aa9bd51bbe4a31 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Fri, 31 Oct 2025 23:30:26 +0000
Subject: [PATCH 2/2] update

---
 examples/dynamo/nvrtc_aot_plugin.py | 194 ++++++++++++----------------
 1 file changed, 82 insertions(+), 112 deletions(-)

diff --git a/examples/dynamo/nvrtc_aot_plugin.py b/examples/dynamo/nvrtc_aot_plugin.py
index 04b3ea7b72..b66952ca2e 100644
--- a/examples/dynamo/nvrtc_aot_plugin.py
+++ b/examples/dynamo/nvrtc_aot_plugin.py
@@ -12,37 +12,46 @@
 
 import torch
 
-# Import triton for kernel implementation
-import triton
-import triton.language as tl
-
 import torch_tensorrt
 
-# ============================================================================
-# Triton Kernel for Eager Execution
-# ============================================================================
+# CUDA kernel source (NVRTC) used by the torch custom op
+cu_code = """
+#include <cuda_fp16.h>
+
+// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
+__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
+                                               __half* __restrict__ output,
+                                               const int size) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (idx < size) {
+        const float x = __half2float(input[idx]);
+        const float result = 1.0f / (1.0f + expf(-x));
+        output[idx] = __float2half(result);
+    }
+}
+"""
 
+# Prepare NVRTC program, kernel, and stream once (simple eager path)
+from cuda.core.experimental import (
+    Device as _CudaDevice,
+    LaunchConfig as _LaunchConfig,
+    Program as _CudaProgram,
+    ProgramOptions as _CudaProgramOptions,
+    launch as _cuda_launch,
+)
 
-@triton.jit
-def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    # Program ID determines the block of data each thread will process
-    pid = tl.program_id(0)
-    # Compute the range of elements that this thread block will work on
-    block_start = pid * BLOCK_SIZE
-    # Range of indices this thread will handle
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    # Mask for boundary checking
-    mask = offsets < n_elements
-    # Load elements from the X tensor
-    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-    # Convert to float32 for computation
-    x_f32 = x.to(tl.float32)
-    # Compute sigmoid: 1 / (1 + exp(-x))
-    output = tl.sigmoid(x_f32)
-    # Convert back to original dtype
-    output_casted = output.to(x.dtype)
-    # Store the result in Y
-    tl.store(y_ptr + offsets, output_casted, mask=mask)
+_cuda_device = _CudaDevice()
+_cuda_device.set_current()
+_cuda_stream = _cuda_device.create_stream()
+_program_options = _CudaProgramOptions(
+    std="c++17", arch=f"sm_{_cuda_device.arch}", include_path=["/usr/local/cuda/include"]
+)
+_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
+_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
+_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
+
+# Eager torch custom_op implemented using the CUDA kernel above (no Triton)
 
 
 # ============================================================================
@@ -52,20 +61,37 @@ def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr)
 
 @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
 def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
-    # Ensure the tensor is on the GPU
     assert X.is_cuda, "Tensor must be on CUDA device."
- # Create output tensor Y = torch.empty_like(X) + N = int(X.numel()) + + block = 256 - # Define block size - BLOCK_SIZE = 256 + grid_x = max(1, (N + block - 1) // block) + config = _LaunchConfig(grid=(grid_x), block=(block)) - # Grid of programs - grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) + # Use PyTorch's current stream by wrapping it for cuda.core + class _PyTorchStreamWrapper: + def __init__(self, pt_stream): + self.pt_stream = pt_stream - # Launch the kernel - pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE) + def __cuda_stream__(self): + stream_id = self.pt_stream.cuda_stream + return (0, stream_id) + + pt_stream = torch.cuda.current_stream() + s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream)) + + # Launch kernel with raw pointers as in cuda.core example + _cuda_launch( + s, + config, + _kernel, + X.data_ptr(), + Y.data_ptr(), + N, + ) return Y @@ -97,66 +123,6 @@ def sigmoid_autotune( return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")] -# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") -# def sigmoid_aot_triton_impl( -# input: trtp.TensorDesc, -# outputs: Tuple[trtp.TensorDesc], -# tactic: int, -# ) -> Tuple[ -# Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs -# ]: -# print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!") - -# # Reuse the same Triton kernel we use for eager execution -# src = triton.compiler.ASTSource( -# fn=pointwise_sigmoid_kernel, -# signature={ -# "x_ptr": "*fp16", -# "y_ptr": "*fp16", -# "n_elements": "i32", -# "BLOCK_SIZE": "constexpr", -# }, -# constexprs={"BLOCK_SIZE": 256}, -# ) - -# compiled_kernel = triton.compile(src) - -# N = input.shape_expr.numel() -# launch_params = trtp.KernelLaunchParams() -# launch_params.grid_x = trtp.cdiv(N, 256) -# launch_params.block_x = compiled_kernel.metadata.num_warps * 32 -# launch_params.shared_mem = compiled_kernel.metadata.shared - -# extra_args = trtp.SymIntExprs(1) -# extra_args[0] = trtp.SymInt32(N) - -# print(compiled_kernel.asm["ptx"]) - -# return ( -# compiled_kernel.metadata.name, -# compiled_kernel.asm["ptx"], -# launch_params, -# extra_args, -# ) - - -cu_code = """ -#include - -// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x)) -__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input, - __half* __restrict__ output, - const int size) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < size) { - const float x = __half2float(input[idx]); - const float result = 1.0f / (1.0f + expf(-x)); - output[idx] = __float2half(result); - } -} -""" - @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") def sigmoid_aot_nvrtc_impl( @@ -166,22 +132,19 @@ def sigmoid_aot_nvrtc_impl( ) -> Tuple[ Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs ]: - print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!") - dev = Device() - dev.set_current() - program_options = ProgramOptions( - std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"] - ) - program = Program(cu_code, code_type="c++", options=program_options) - mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",)) - compiled_kernel = mod.code.decode("utf-8") + compiled_kernel= _module.code.decode("utf-8") + print(type(compiled_kernel)) print(compiled_kernel) + # import pdb; pdb.set_trace() + + N = input.shape_expr.numel() launch_params = trtp.KernelLaunchParams() - launch_params.grid_x = trtp.cdiv((N + 256 - 1), 256) - 
launch_params.block_x = 256 + block = 256 + launch_params.grid_x = trtp.cdiv(N, block) + launch_params.block_x = block launch_params.shared_mem = 0 extra_args = trtp.SymIntExprs(1) @@ -217,26 +180,33 @@ class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - x = torch.mul(input, 2) - y = torch.div(x, 2) - z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y) - a = torch.add(z, 1) - return a + + z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input) + return z if __name__ == "__main__": model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval() input = torch.randn(1, 1024, device="cuda", dtype=torch.float16) + print(torch.sigmoid(input)) + + print(model(input)) + with torch_tensorrt.logging.debug(): trt_inputs = [input] model_trt = torch_tensorrt.compile( model, inputs=trt_inputs, + enabled_precisions={torch.float16}, min_block_size=1, ) print("Model compiled successfully!") print("Running inference with compiled model...") + print("Compiled model output:") + print(model_trt(input)) + print("Original model output:") + print(model(input)) with torch.no_grad(): for i in range(10): res = model_trt(input) @@ -244,4 +214,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: res, model(input), rtol=1e-2, atol=1e-2 ), "Results do not match!" - print("Inference successful!") + # print("Inference successful!")
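
A quick eager-only sanity check of the registered op is useful before involving TensorRT, since it isolates NVRTC/kernel problems from plugin and converter problems. The sketch below is not part of the patch; it assumes the example file above is importable as `nvrtc_aot_plugin` (a name taken from the file path, adjust to your setup):

    import torch
    import nvrtc_aot_plugin  # noqa: F401  (importing the module registers pointwise_sigmoid_ops::pointwise_sigmoid)

    x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
    # The eager path launches the NVRTC-compiled __half kernel on the current torch stream.
    y = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(x)
    torch.testing.assert_close(y, torch.sigmoid(x), rtol=1e-2, atol=1e-2)

If this check passes but the compiled TensorRT module does not, the problem lies in the plugin registration or the AOT implementation rather than in the kernel itself.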