diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 979a94eb7c6..00b516e140f 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -1059,6 +1059,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
     const paddle::Tensor& decode_states,
     const paddle::Tensor& mask_rollback);
 
+void gelu_tanh(paddle::Tensor& output, paddle::Tensor& input, bool enable_pdl);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num",
         &GetExpertTokenNum,
@@ -1648,4 +1650,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("update_attn_mask_offsets",
         &UpdateAttnMaskOffsets,
         "update attention mask");
+
+  m.def("gelu_tanh", &gelu_tanh, "GELU activation with tanh approximation");
 }
diff --git a/custom_ops/gpu_ops/gelu_tanh.cu b/custom_ops/gpu_ops/gelu_tanh.cu
new file mode 100644
index 00000000000..ce57149b378
--- /dev/null
+++ b/custom_ops/gpu_ops/gelu_tanh.cu
@@ -0,0 +1,88 @@
+#include "helper.h"
+
+__forceinline__ __device__ float tanh_ptx(float x) {
+  float y;
+  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+__device__ __forceinline__ float gelu_tanh_func(const float& val) {
+  const float cdf =
+      0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
+                               (val + 0.044715f * val * val * val))));
+  return val * cdf;
+}
+
+template <typename T>
+__global__ void gelu_tanh_kernel(T* __restrict__ out,
+                                 const T* __restrict__ input,
+                                 const int d) {
+  constexpr uint32_t kVecSize = 16 / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * d;
+  using vec_t = AlignedVector<T, kVecSize>;
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
+    vec_t x_vec;
+    Load(input + offset + idx * kVecSize, &x_vec);
+#pragma unroll
+    for (uint32_t i = 0; i < kVecSize; ++i) {
+      x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
+    }
+    Store(x_vec, out + token_idx * d + idx * kVecSize);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * kVecSize);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
+    float x = input[offset + remaining_offset + idx];
+    out[token_idx * d + remaining_offset + idx] =
+        static_cast<T>(gelu_tanh_func(x));
+  }
+
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+void gelu_tanh(paddle::Tensor& output, paddle::Tensor& input, bool enable_pdl) {
+  int d = input.dims()[1];
+  int64_t num_tokens = input.dims()[0];
+  cudaStream_t stream = input.stream();
+
+  DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
+    uint32_t vec_size = 16 / sizeof(scalar_t);
+    cudaLaunchConfig_t config;
+    config.gridDim = num_tokens;
+    config.blockDim = std::min(d / vec_size, 1024U);
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
+    config.numAttrs = 1;
+    config.attrs = attrs;
+
+    cudaLaunchKernelEx(&config,
+                       gelu_tanh_kernel<scalar_t>,
+                       output.data<scalar_t>(),
+                       input.data<scalar_t>(),
+                       d);
+  });
+}
+
+PD_BUILD_STATIC_OP(gelu_tanh)
+    .Inputs({"output", "input"})
+    .Attrs({"enable_pdl:bool"})
+    .Outputs({"out"})
+    .SetInplaceMap({{"output", "out"}})
+    .SetKernelFn(PD_KERNEL(gelu_tanh));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 4903f3dd575..72ed0a2b11c 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -306,6 +306,7 @@ def find_end_files(directory, end_str):
         "gpu_ops/limit_thinking_content_length_v1.cu",
         "gpu_ops/limit_thinking_content_length_v2.cu",
         "gpu_ops/update_attn_mask_offsets.cu",
+        "gpu_ops/gelu_tanh.cu",
     ]
 
     # pd_disaggregation
diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py
index 79fd3b24f64..7b8a1197918 100644
--- a/fastdeploy/model_executor/layers/activation.py
+++ b/fastdeploy/model_executor/layers/activation.py
@@ -23,6 +23,24 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.platforms import current_platform
 
+from .utils import is_arch_support_pdl
+
+
+def gelu_tanh(
+    input: paddle.Tensor,
+    out: Optional[paddle.Tensor] = None,
+    enable_pdl: Optional[bool] = None,
+):
+    """GELU activation with tanh approximation, writing the result into ``out``."""
+    from fastdeploy.model_executor.ops.gpu import gelu_tanh
+
+    if out is None:
+        out = paddle.empty_like(input)
+    if enable_pdl is None:
+        enable_pdl = is_arch_support_pdl()
+    gelu_tanh(out, input, enable_pdl)
+    return out
+
 
 class SiluAndMul(nn.Layer):
     """
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index c0644896e8e..d1a83faeff1 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -402,3 +402,10 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, ran
 def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
     per_partition_vocab_size = divide(global_vocab_size, world_size)
     return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset)
+
+
+@functools.lru_cache(maxsize=1)
+def is_arch_support_pdl() -> bool:
+    # PDL requires Hopper (compute capability 9.0) or newer
+    prop = paddle.device.cuda.get_device_properties()
+    return prop.major >= 9
diff --git a/tests/operators/test_gelu_tanh.py b/tests/operators/test_gelu_tanh.py
new file mode 100644
index 00000000000..450d83d5897
--- /dev/null
+++ b/tests/operators/test_gelu_tanh.py
@@ -0,0 +1,43 @@
+import math
+import unittest
+from itertools import product
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.layers.activation import gelu_tanh
+
+
+class TestGeluTanh(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2024)
+        print(paddle.device.cuda.get_device_properties())
+        print(paddle.__git_commit__)
+
+    def native_gelu_tanh(
+        self,
+        input: paddle.Tensor,
+    ):
+        x_fp32 = input.cast("float32")
+        out = (
+            0.5
+            * x_fp32
+            * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (x_fp32 + 0.044715 * paddle.pow(x_fp32, 3.0))))
+        )
+        return out.cast(input.dtype)
+
+    def test_gelu_tanh(self):
+        bszs = [1, 32, 64, 128, 1024]
+        hidden_sizes = [4096, 7168]
+        test_cases = product(bszs, hidden_sizes)
+        for bsz, hidden_size in test_cases:
+            shape = [bsz, hidden_size]
+            input = paddle.randn(shape, dtype="float16")
+            out_ref = self.native_gelu_tanh(input)
+
+            out = gelu_tanh(input)
+            np.testing.assert_allclose(out_ref.numpy(), out.numpy(), rtol=1e-03, atol=1e-03)
+
+
+if __name__ == "__main__":
+    unittest.main()
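
A minimal usage sketch of the new Python wrapper, not part of the patch itself. It assumes a CUDA build of FastDeploy with the custom ops compiled; the shapes and dtype below are illustrative only:

    import paddle

    from fastdeploy.model_executor.layers.activation import gelu_tanh

    # [num_tokens, hidden_size] activations, matching the kernel's 2-D layout.
    x = paddle.randn([32, 4096], dtype="float16")

    # Default path: allocates the output and picks enable_pdl automatically.
    y = gelu_tanh(x)

    # Alternative path: write into a preallocated buffer with PDL forced off.
    out = paddle.empty_like(x)
    gelu_tanh(x, out=out, enable_pdl=False)

When enable_pdl is left as None, the wrapper consults is_arch_support_pdl(), which returns True only for compute capability 9.0 and newer, so programmatic dependent launch is disabled automatically on pre-Hopper GPUs; the kernel's griddepcontrol instructions are likewise compiled out below SM 9.0 / CUDA 12.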