From 6e2e8b5c4d313648e92553fa30fbd8023c1f6506 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Fri, 7 Nov 2025 17:54:01 +0800 Subject: [PATCH 1/2] fix eplb --- fastdeploy/eplb/async_expert_loader.py | 9 ++++++--- fastdeploy/eplb/experts_manager.py | 3 ++- fastdeploy/model_executor/layers/moe/ep.py | 2 ++ .../layers/moe/fused_moe_cutlass_backend.py | 4 ++++ .../layers/moe/fused_moe_deepgemm_backend.py | 4 ++++ fastdeploy/model_executor/layers/moe/moe.py | 3 ++- fastdeploy/model_executor/models/ernie4_5_moe.py | 2 +- fastdeploy/worker/worker_process.py | 16 +++++++++++++--- 8 files changed, 34 insertions(+), 9 deletions(-) diff --git a/fastdeploy/eplb/async_expert_loader.py b/fastdeploy/eplb/async_expert_loader.py index 14ca9990142..92b53637544 100644 --- a/fastdeploy/eplb/async_expert_loader.py +++ b/fastdeploy/eplb/async_expert_loader.py @@ -162,6 +162,9 @@ def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None): # NumPy 不支持 bfloat16,因此先以 uint16 读取原始数据,再用 Paddle cast 为 bfloat16 tmp = np_array.view(np.uint16) tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True) + elif dtype == paddle.float8_e4m3fn: + tmp = np_array.view(np.uint8) + tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True) else: raise TypeError(f"Unsupported dtype: {dtype}") @@ -306,8 +309,8 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]): """ up_gate_down = ["up_gate_proj", "down_proj"] quant_weight_scale = ["quant_weight", "weight_scale"] - if self.moe_quant_type == "w4a8": - quant_weight_scale = ["quant_weight"] + # if self.moe_quant_type == "w4a8": + # quant_weight_scale = ["quant_weight"] ckpt_name = [ (f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}") for layer_id, expert_id in need_to_reload @@ -324,7 +327,7 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]): from safetensors import safe_open for st_file in hf_weights_files: - with safe_open(st_file, framework="np", device="cpu") as f: + with safe_open(st_file, framework="paddle", device="cpu") as f: for name in f.keys(): if name in ckpt_name: weight = f.get_tensor(name) diff --git a/fastdeploy/eplb/experts_manager.py b/fastdeploy/eplb/experts_manager.py index c8a2ea19789..d2af1a1b168 100644 --- a/fastdeploy/eplb/experts_manager.py +++ b/fastdeploy/eplb/experts_manager.py @@ -122,7 +122,7 @@ def __init__( self.http_timeout = 1 # 重置重排状态: 'done' -> 'free' self.rearrange_end_ts = 0 - self.rearrange_reset_interval = 300 + self.rearrange_reset_interval = 30 self.tensor_infos = None @@ -437,6 +437,7 @@ def allreduce_load_weight_result(self): # prefill需要等待调度屏蔽 if ( self.fd_config.splitwise_role == "decode" + or self.fd_config.splitwise_role == "mixed" or not self.eplb_config.redundant_expert_enable_schedule_cordon ): self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py") diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 4b4de0b5ac6..73a42cc1343 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -436,6 +436,8 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): expert_in_rank_num_list, tokens_per_expert_stats_list, ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx) + if layer.is_rearrange is False: + expert_id_to_ep_rank_array = paddle.arange(layer.num_experts).cast("int32") if layer.topk_method == "noaux_tc": from .moe import get_moe_scores diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 31f52063b40..9dd5c9984a9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -971,6 +971,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: # self.check(layer, up_gate_proj_weights, down_proj_weights) up_gate_proj_weight_scale = [] down_proj_weight_scale = [] + + if isinstance(state_dict, list): + state_dict = dict(state_dict) + for expert_idx in logical_expert_ids: up_gate_proj_weight_scale.append( get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 6c558efbab7..178bc74b342 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -246,6 +246,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: # self.check(layer, up_gate_proj_weights, down_proj_weights) up_gate_proj_weight_scale = [] down_proj_weight_scale = [] + + if isinstance(state_dict, list): + state_dict = dict(state_dict) + for expert_idx in logical_expert_ids: up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx) down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 49a215f2014..8b83aeccabe 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -191,6 +191,7 @@ def __init__( else: self.quant_method = get_moe_method() self.redundant_table_manger = redundant_table_manger + self.is_rearrange = False if self.ep_size > 1: self.quant_method.init_ep(self) @@ -397,7 +398,7 @@ def load_experts_weight( ) ] ep_rank_to_expert_id_list = [i for i in range(self.num_experts)] - if self.redundant_table_manger is not None: + if self.redundant_table_manger is not None and is_rearrange is True: ( ep_rank_to_expert_id_list, expert_id_to_ep_rank_array, diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 96758472deb..9e02a44ef2c 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -214,7 +214,7 @@ def load_state_dict(self, state_dict): self.shared_experts.load_state_dict(state_dict) def update_state_dict(self, state_dict): - self.fused_moe.load_state_dict(state_dict, True) + self.experts.load_state_dict(state_dict, True) def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int): token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index fb312db249e..84ed49948e2 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -262,13 +262,21 @@ def update_weights_from_tensor(self, mmap_infos): """ update_weights_from_tensor """ + import time + + while True: + if self.experts_manager.tensor_infos is None: + time.sleep(0.1) + else: + break state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger) rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list() - self.worker.get_model().redundant_table_manger.update_expert_rank_table( + self.worker.get_model().ernie.redundant_table_manger.update_expert_rank_table( rank_expert_list, logical_to_physical_map, expert_count ) # TO BE FIXED - self.worker.get_model().update_state_dict(state_dicts) + self.worker.get_model().ernie.update_state_dict(state_dicts) + self.experts_manager.tensor_infos = None def _broadcast_model_weights_signal(self, src: int, group) -> int: model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32") @@ -362,7 +370,9 @@ def event_loop_normal(self) -> None: _, _, _, - ) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat) + ) = self.worker.get_model().ernie.redundant_table_manger.get_expert_tokens_stats( + clear_stat=clear_stat + ) local_experts_token_stats_array.value[:] = new_stats_array[:] elif local_experts_token_stats_array.value is None: logger.warning("redundant_expert: local_experts_token_stats not init") From 7605371ec1672caa2d5e36d3020aa3d1349418c4 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Sun, 9 Nov 2025 23:55:05 +0800 Subject: [PATCH 2/2] fix eplb --- fastdeploy/entrypoints/engine_client.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 6e0f6bd5e02..665cf4c602e 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -581,13 +581,6 @@ async def rearrange_experts(self, request_dict: dict): if "data" not in request_dict or not isinstance(request_dict["data"], list): content = {"code": 1, "msg": "data not in request or data is not a list"} status_code = HTTPStatus.BAD_REQUEST - - elif len(request_dict["data"]) != len(self.expert_tokens_stats_array_list): - content = { - "code": 1, - "msg": f"actual data length {len(request_dict['data'])}, expect length {len(self.expert_tokens_stats_array_list)}", - } - status_code = HTTPStatus.BAD_REQUEST else: weight = np.array(request_dict["data"], dtype=np.int32) for idx in range(len(self.expert_tokens_stats_array_list)):