4 changes: 4 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -1059,6 +1059,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);

void gelu_tanh(paddle::Tensor& output, paddle::Tensor& input, bool enable_pdl);

PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1648,4 +1650,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("update_attn_mask_offsets",
&UpdateAttnMaskOffsets,
"update attention mask");

m.def("gelu_tanh", &gelu_tanh, "gelu_tanh function");
}
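The binding above exposes the new op through pybind11 (`m.def("gelu_tanh", ...)`); a static Paddle op registration follows in `gelu_tanh.cu`. A minimal sketch of calling the raw binding directly, assuming a CUDA build of `fastdeploy_ops` with this change applied (the `activation.py` wrapper further down is the intended entry point):

```python
# Minimal sketch (assumes a CUDA build with this PR applied): the raw op
# writes in place into its first argument and takes the PDL flag explicitly.
import paddle
from fastdeploy.model_executor.ops.gpu import gelu_tanh

x = paddle.randn([8, 4096], dtype="float16")
out = paddle.empty_like(x)
gelu_tanh(out, x, False)   # enable_pdl=False is the safe choice on pre-Hopper GPUs
```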
88 changes: 88 additions & 0 deletions custom_ops/gpu_ops/gelu_tanh.cu
@@ -0,0 +1,88 @@
#include "helper.h"

__forceinline__ __device__ float tanh_ptx(float x) {
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}

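// 0.7978845608028654f == sqrt(2 / pi); 0.044715f is the standard coefficient
// of the tanh approximation of GELU.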
__device__ __forceinline__ float gelu_tanh_func(const float& val) {
const float cdf =
0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
(val + 0.044715f * val * val * val))));
return val * cdf;
}

template <typename T>
__global__ void gelu_tanh_kernel(T* __restrict__ out,
const T* __restrict__ input,
const int d) {
constexpr uint32_t kVecSize = 16 / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
using vec_t = AlignedVector<T, kVecSize>;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif

#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
vec_t x_vec;
Load(input + offset + idx * kVecSize, &x_vec);
#pragma unroll
for (uint32_t i = 0; i < kVecSize; ++i) {
x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
}
Store(x_vec, out + token_idx * d + idx * kVecSize);
}

const int64_t remaining_offset = d - d % (stride * kVecSize);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
float x = input[offset + remaining_offset + idx];
out[token_idx * d + remaining_offset + idx] =
static_cast<T>(gelu_tanh_func(x));
}

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

void gelu_tanh(paddle::Tensor& output, paddle::Tensor& input, bool enable_pdl) {
int d = input.dims()[1];
int64_t num_tokens = input.dims()[0];
cudaStream_t stream = input.stream();

DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;

cudaLaunchKernelEx(&config,
gelu_tanh_kernel<scalar_t>,
output.data<scalar_t>(),
input.data<scalar_t>(),
d);
});
}

PD_BUILD_STATIC_OP(gelu_tanh)
.Inputs({"output", "input"})
.Attrs({"enable_pdl:bool"})
.Outputs({"out"})
.SetInplaceMap({{"output", "out"}})
.SetKernelFn(PD_KERNEL(gelu_tanh));
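The launcher above assigns one thread block per token (`gridDim = num_tokens`) and caps the block at 1024 threads, with each thread processing 16-byte vectors (8 elements for fp16/bf16). A small sketch of that arithmetic, mirroring the host code; the helper name is illustrative only:

```python
# Sketch of the launch geometry computed by the host code above:
# one block per token, 16-byte vectors per thread, block size capped at 1024.
def launch_geometry(num_tokens: int, d: int, dtype_bytes: int = 2):
    vec_size = 16 // dtype_bytes            # 8 for fp16/bf16
    block_dim = min(d // vec_size, 1024)    # mirrors std::min(d / vec_size, 1024U)
    return {"gridDim": num_tokens, "blockDim": block_dim, "vec_size": vec_size}

print(launch_geometry(32, 4096))   # {'gridDim': 32, 'blockDim': 512, 'vec_size': 8}
print(launch_geometry(32, 7168))   # {'gridDim': 32, 'blockDim': 896, 'vec_size': 8}
```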
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
@@ -306,6 +306,7 @@ def find_end_files(directory, end_str):
"gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/gelu_tanh.cu",
]

# pd_disaggregation
18 changes: 18 additions & 0 deletions fastdeploy/model_executor/layers/activation.py
@@ -23,6 +23,24 @@
from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform

from .utils import is_arch_support_pdl


def gelu_tanh(
input: paddle.Tensor,
out: Optional[paddle.Tensor] = None,
enable_pdl: Optional[bool] = None,
):
"""GeLU Tanh operation."""
from fastdeploy.model_executor.ops.gpu import gelu_tanh

if out is None:
out = paddle.empty_like(input)
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
gelu_tanh(out, input, enable_pdl)
return out


class SiluAndMul(nn.Layer):
"""
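The wrapper allocates the output buffer when none is passed and auto-detects PDL support. A hedged sanity check against Paddle's built-in tanh-approximate GELU, assuming a CUDA device and this build:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.layers.activation import gelu_tanh

x = paddle.randn([4, 4096], dtype="float16")
ref = F.gelu(x.cast("float32"), approximate=True).cast("float16")
out = gelu_tanh(x)   # allocates the output and chooses enable_pdl itself
print(np.allclose(ref.numpy(), out.numpy(), rtol=1e-3, atol=1e-3))  # expect True
```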
7 changes: 7 additions & 0 deletions fastdeploy/model_executor/layers/utils.py
@@ -402,3 +402,10 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, ran
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset)


@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
# PDL requires Hopper (compute capability 9.0) or newer
prop = paddle.device.cuda.get_device_properties()
return prop.major >= 9
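`cudaLaunchAttributeProgrammaticStreamSerialization` (programmatic dependent launch) requires compute capability 9.0 or newer, which is what this cached helper checks. A short sketch of the expected behavior, assuming a CUDA-enabled Paddle build:

```python
import paddle

from fastdeploy.model_executor.layers.utils import is_arch_support_pdl

prop = paddle.device.cuda.get_device_properties()
print(prop.major, prop.minor)   # e.g. 9, 0 on an H800; 8, 0 on an A100
print(is_arch_support_pdl())    # True only for SM 9.0+ (Hopper and newer)
```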
43 changes: 43 additions & 0 deletions tests/operators/test_gelu_tanh.py
@@ -0,0 +1,43 @@
import math
import unittest
from itertools import product

import numpy as np
import paddle

from fastdeploy.model_executor.layers.activation import gelu_tanh


class TestGeluTanh(unittest.TestCase):
def setUp(self):
paddle.seed(2024)
print(paddle.device.cuda.get_device_properties())
print(paddle.__git_commit__)

def native_gelu_tanh(
self,
input: paddle.Tensor,
):
x_fp32 = input.cast("float32")
out = (
0.5
* x_fp32
* (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (x_fp32 + 0.044715 * paddle.pow(x_fp32, 3.0))))
)
return out.cast(input.dtype)

def test_gelu_tanh(self):
bszs = [1, 32, 64, 128, 1024]
hidden_sizes = [4096, 7168]
test_cases = product(bszs, hidden_sizes)
for bsz, hidden_size in test_cases:
shape = [bsz, hidden_size]
input = paddle.randn(shape, dtype="float16")
out_ref = self.native_gelu_tanh(input)

out = gelu_tanh(input)
np.testing.assert_allclose(out_ref.numpy(), out.numpy(), rtol=1e-03, atol=1e-03)


if __name__ == "__main__":
unittest.main()