Conversation

@yuanlehome (Collaborator) commented Nov 6, 2025

Motivation

#4688 and #4836 added TSP-EP hybrid parallelism for ernie4.5 moe on XPU, but only the ernie4.5 moe model was verified and the implementation was not general. This PR adds and verifies GPU + CUDAGraph support for Qwen3 moe, and the approach applies to other MoE models as well.

Modifications

  • Support the V1 loader
  • One bug fix
  • Some necessary model-network adjustments
  • The SP optimization is enabled by default; to disable it, add --disable-sequence-parallel-moe

Usage or Command

DP2TP4EP8

MODEL_PATH="/path/to/Qwen3-235B-A22B-Instruct-2507-FP8/"

# FD related
export FD_ENABLE_MULTI_API_SERVER=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export FD_USE_DEEP_GEMM=1

python -m fastdeploy.entrypoints.openai.multi_api_server \
    --port "1211,1222" \
    --num-servers 2 \
    --metrics-port "3112,3212" \
    --args \
    --host $(hostname -i) \
    --model "$MODEL_PATH" \
    --engine-worker-queue-port "1477,1478" \
    --cache-queue-port "55660,55661" \
    --disable-custom-all-reduce \
    --tensor-parallel-size 4 \
    --data-parallel-size 2 \
    --enable-expert-parallel \
    --max-model-len 65536 \
    --max-num-seqs 32 \
    --quantization block_wise_fp8 \
    --load-choices default_v1 \
    --graph-optimization-config '{"use_cudagraph":true,"use_unique_memory_pool":true}'

Accuracy Tests

None.

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot bot commented Nov 6, 2025

Thanks for your contribution!

Copilot AI review requested due to automatic review settings November 7, 2025 12:25
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR refactors the RMSNorm layer to consistently return a tuple of (output, residual_out) in all cases, and updates all model files to handle this new API. Additionally, it includes bug fixes for environment variable handling and quantization logic.

  • Standardizes RMSNorm to always return a tuple regardless of whether residual_input is provided
  • Updates all model files to unpack the tuple return value using [0] indexing
  • Fixes environment variable type conversion for stop sequence configuration
  • Corrects quantization logic for handling None values in output_dim
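The new RMSNorm contract and both call-site patterns described above can be sketched as follows (a pure-Python toy with a simplified list-based RMSNorm, not FastDeploy's real tensor implementation; signatures are assumptions for illustration):

```python
from typing import List, Optional, Tuple


class RMSNorm:
    """Toy stand-in for an RMSNorm layer whose forward() now always
    returns a (output, residual_out) tuple."""

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def forward(
        self, x: List[float], residual_input: Optional[List[float]] = None
    ) -> Tuple[List[float], Optional[List[float]]]:
        # Add the residual first (if provided), then normalize.
        if residual_input is not None:
            x = [a + b for a, b in zip(x, residual_input)]
        residual_out = x if residual_input is not None else None
        mean_sq = sum(v * v for v in x) / len(x)
        inv_rms = (mean_sq + self.eps) ** -0.5
        out = [v * inv_rms for v in x]
        # New contract: always a tuple, even when residual_input is None.
        return out, residual_out


norm = RMSNorm()
# Call sites without a residual now unpack with [0] indexing:
hidden = norm.forward([3.0, 4.0])[0]
# Call sites with a residual unpack both elements:
hidden2, residual = norm.forward([3.0, 4.0], residual_input=[1.0, 1.0])
```

The point of always returning a tuple is uniformity: callers no longer branch on whether a residual was passed, which also keeps traced graphs (e.g. under CUDAGraph capture) structurally identical across both paths.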

Reviewed Changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.

Summary per file:

  • fastdeploy/model_executor/layers/normalization.py: Refactors RMSNorm.forward() to always return the tuple (output, residual_out)
  • fastdeploy/model_executor/models/qwen3moe.py: Updates RMSNorm calls to use the new API with tuple unpacking
  • fastdeploy/model_executor/models/qwen3.py: Updates RMSNorm calls to extract the first element from the returned tuple
  • fastdeploy/model_executor/models/qwen2.py: Updates RMSNorm calls and input_layernorm parameter names
  • fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py: Updates norm call to use tuple unpacking
  • fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py: Updates norm call to use tuple unpacking
  • fastdeploy/model_executor/models/gpt_oss.py: Updates RMSNorm calls to use the new API
  • fastdeploy/model_executor/models/glm4_moe.py: Updates q_norm and k_norm calls with tuple unpacking
  • fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py: Updates RMSNorm calls to use the new API
  • fastdeploy/model_executor/models/ernie4_5_mtp.py: Updates multiple norm calls to use tuple unpacking
  • fastdeploy/model_executor/models/ernie4_5_moe.py: Updates RMSNorm calls and removes forward_meta from post_attention_layernorm
  • fastdeploy/model_executor/models/deepseek_v3.py: Updates layernorm and RMSNorm calls to use the new API
  • fastdeploy/model_executor/layers/quantization/block_wise_fp8.py: Adds a None check before negating the output_dim boolean
  • fastdeploy/model_executor/layers/linear.py: Fixes reshape dimension and output_dim assignment
  • fastdeploy/input/text_processor.py: Removes unused return_tensors parameter
  • fastdeploy/envs.py: Converts environment variables to int type at the source
  • fastdeploy/entrypoints/engine_client.py: Removes redundant int() conversions
  • fastdeploy/engine/engine.py: Removes redundant int() conversions
  • fastdeploy/config.py: Removes redundant int() conversions
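The env-var cleanup listed for fastdeploy/envs.py can be sketched like this (the variable name FD_STOP_SEQS_MAX_LEN and the helper are hypothetical and only illustrate the pattern of converting at the definition site so call sites can drop their redundant int() casts):

```python
import os

# Hypothetical env-var name for illustration; FastDeploy's actual
# registry in fastdeploy/envs.py differs. Converting to int where the
# variable is defined means every reader already gets an int.
FD_STOP_SEQS_MAX_LEN = int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8"))


def build_stop_config() -> dict:
    # No redundant int() cast needed at the call site anymore.
    return {"max_stop_seqs": FD_STOP_SEQS_MAX_LEN}
```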
Comments suppressed due to low confidence (2)

fastdeploy/model_executor/models/qwen3moe.py:206

  • The post_attention_layernorm call is missing the forward_meta parameter. At lines 196-198, input_layernorm is called with forward_meta; for consistency and proper parallel-execution handling, post_attention_layernorm should receive it as well.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
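The consistency the reviewer asks for can be illustrated with a toy layer (class and argument names are illustrative, not FastDeploy's real signatures):

```python
# Toy decoder layer showing the reviewer's point: both norm calls
# should thread forward_meta through in the same way.
class DecoderLayer:
    def __init__(self, input_layernorm, post_attention_layernorm):
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

    def forward(self, hidden_states, residual, forward_meta):
        # input_layernorm already receives forward_meta ...
        hidden_states, residual = self.input_layernorm(hidden_states, residual, forward_meta)
        # ... so post_attention_layernorm should receive it too.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, forward_meta)
        return hidden_states, residual


calls = []


def norm_a(h, r, meta):
    calls.append(("input", meta))
    return h, r


def norm_b(h, r, meta):
    calls.append(("post_attn", meta))
    return h, r


layer = DecoderLayer(norm_a, norm_b)
out, res = layer.forward([1.0], None, {"step": 0})
```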

fastdeploy/model_executor/layers/normalization.py:1

  • The docstring states 'If residual_input is None, returns the normalized output tensor' and 'If residual_input is provided, returns a tuple', but the implementation now always returns a tuple (out, residual_out) regardless of whether residual_input is None. The documentation should be updated to reflect that this method always returns a tuple.
"""

@yuanlehome changed the title from "Support qwen3 moe tsp" to "[TSP] Support qwen3 moe tsp + cudagraph" on Nov 7, 2025
zhupengyang (Collaborator) previously approved these changes Nov 10, 2025

LGTM

carryyu (Collaborator) previously approved these changes Nov 10, 2025

LGTM

@yuanlehome yuanlehome dismissed stale reviews from carryyu and zhupengyang via 5f2f1cb November 10, 2025 04:56
qingqing01 (Collaborator) previously approved these changes Nov 10, 2025

LGTM

gongshaotian (Collaborator) previously approved these changes Nov 10, 2025

LGTM

@yuanlehome yuanlehome dismissed stale reviews from gongshaotian and qingqing01 via 0082b1a November 10, 2025 07:46
@gongshaotian (Collaborator) left a comment

LGTM

@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 3dc0ffa into PaddlePaddle:develop Nov 10, 2025
14 of 16 checks passed