Merged
7 changes: 0 additions & 7 deletions fastdeploy/entrypoints/engine_client.py

@@ -581,13 +581,6 @@ async def rearrange_experts(self, request_dict: dict):
         if "data" not in request_dict or not isinstance(request_dict["data"], list):
             content = {"code": 1, "msg": "data not in request or data is not a list"}
             status_code = HTTPStatus.BAD_REQUEST
-
-        elif len(request_dict["data"]) != len(self.expert_tokens_stats_array_list):
-            content = {
-                "code": 1,
-                "msg": f"actual data length {len(request_dict['data'])}, expect length {len(self.expert_tokens_stats_array_list)}",
-            }
-            status_code = HTTPStatus.BAD_REQUEST
         else:
             weight = np.array(request_dict["data"], dtype=np.int32)
             for idx in range(len(self.expert_tokens_stats_array_list)):
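Behavioral note on this deletion: a payload whose outer length mismatches the per-layer stats arrays is no longer rejected with HTTP 400 here; it now falls through to the `np.array(...)` path. A minimal sketch of what the removed guard checked (stand-in values; names copied from the hunk):

```python
data = [[1, 2], [3, 4]]                      # illustrative request payload
expert_tokens_stats_array_list = [None] * 3  # stand-in for the per-layer stats arrays
if len(data) != len(expert_tokens_stats_array_list):
    # previously answered with {"code": 1, ...} and HTTPStatus.BAD_REQUEST
    print(f"actual data length {len(data)}, expect length {len(expert_tokens_stats_array_list)}")
```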
9 changes: 6 additions & 3 deletions fastdeploy/eplb/async_expert_loader.py

@@ -162,6 +162,9 @@ def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
         # NumPy does not support bfloat16, so read the raw data as uint16 first, then cast it to bfloat16 with Paddle
         tmp = np_array.view(np.uint16)
         tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True)
+    elif dtype == paddle.float8_e4m3fn:
+        tmp = np_array.view(np.uint8)
+        tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)
     else:
         raise TypeError(f"Unsupported dtype: {dtype}")

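The pattern behind this hunk, as a standalone sketch: NumPy has no native bfloat16 or float8 dtypes, so the raw shared-memory bytes are viewed as a same-width integer dtype and handed to Paddle, which relabels them without copying (`shm_bytes_to_tensor` is a hypothetical name):

```python
import numpy as np
import paddle

def shm_bytes_to_tensor(raw: np.ndarray, dtype):
    """Reinterpret a byte buffer as a Paddle tensor without copying."""
    if dtype == paddle.bfloat16:
        tmp = raw.view(np.uint16)   # bfloat16 is 16 bits wide
    elif dtype == paddle.float8_e4m3fn:
        tmp = raw.view(np.uint8)    # float8_e4m3fn is 8 bits wide
    else:
        raise TypeError(f"Unsupported dtype: {dtype}")
    # zero_copy=True shares the NumPy buffer instead of duplicating it
    return paddle.Tensor(tmp, dtype=dtype, place=paddle.CPUPlace(), zero_copy=True)
```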
@@ -306,8 +309,8 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         """
         up_gate_down = ["up_gate_proj", "down_proj"]
         quant_weight_scale = ["quant_weight", "weight_scale"]
-        if self.moe_quant_type == "w4a8":
-            quant_weight_scale = ["quant_weight"]
+        # if self.moe_quant_type == "w4a8":
+        #     quant_weight_scale = ["quant_weight"]
         ckpt_name = [
             (f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}")
             for layer_id, expert_id in need_to_reload
@@ -324,7 +327,7 @@ def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
         from safetensors import safe_open

         for st_file in hf_weights_files:
-            with safe_open(st_file, framework="np", device="cpu") as f:
+            with safe_open(st_file, framework="paddle", device="cpu") as f:
                 for name in f.keys():
                     if name in ckpt_name:
                         weight = f.get_tensor(name)
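The framework switch matters because NumPy cannot represent the bfloat16/float8 dtypes these checkpoints carry, while `framework="paddle"` makes safetensors return `paddle.Tensor` objects in their native dtype. A minimal sketch (`read_expert_weights` is a hypothetical helper):

```python
from safetensors import safe_open

def read_expert_weights(st_file, wanted):
    """Load only the requested tensors from one safetensors shard."""
    out = {}
    with safe_open(st_file, framework="paddle", device="cpu") as f:
        for name in f.keys():
            if name in wanted:
                out[name] = f.get_tensor(name)  # paddle.Tensor in native dtype
    return out
```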
3 changes: 2 additions & 1 deletion fastdeploy/eplb/experts_manager.py

@@ -122,7 +122,7 @@ def __init__(
         self.http_timeout = 1
         # Reset rearrangement state: 'done' -> 'free'
         self.rearrange_end_ts = 0
-        self.rearrange_reset_interval = 300
+        self.rearrange_reset_interval = 30

         self.tensor_infos = None

@@ -437,6 +437,7 @@ def allreduce_load_weight_result(self):
         # prefill has to wait for the scheduler cordon
         if (
             self.fd_config.splitwise_role == "decode"
+            or self.fd_config.splitwise_role == "mixed"
             or not self.eplb_config.redundant_expert_enable_schedule_cordon
         ):
             self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")
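An equivalent reading of the amended condition, assuming `splitwise_role` is a plain string (sketch; `should_notify_infer` is a hypothetical name):

```python
def should_notify_infer(splitwise_role: str, cordon_enabled: bool) -> bool:
    # decode and mixed roles notify infer.py right away; prefill keeps
    # waiting while the schedule cordon is enabled.
    return splitwise_role in ("decode", "mixed") or not cordon_enabled

assert should_notify_infer("mixed", True)
assert not should_notify_infer("prefill", True)
```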
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/moe/ep.py

@@ -436,6 +436,8 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
             expert_in_rank_num_list,
             tokens_per_expert_stats_list,
         ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)
+        if layer.is_rearrange is False:
+            expert_id_to_ep_rank_array = paddle.arange(layer.num_experts).cast("int32")

         if layer.topk_method == "noaux_tc":
             from .moe import get_moe_scores
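Until a rearrangement has actually run (`layer.is_rearrange` is False), the added guard overrides whatever the table manager returned with an identity layout: logical expert i sits in slot i. A sketch with an illustrative expert count:

```python
import paddle

num_experts = 8  # illustrative
expert_id_to_ep_rank_array = paddle.arange(num_experts).cast("int32")
print(expert_id_to_ep_rank_array.numpy())  # [0 1 2 3 4 5 6 7] -- no remapping yet
```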
(file name not captured in this view)

@@ -971,6 +971,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_weight_scale.append(
                 get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
(file name not captured in this view)

@@ -246,6 +246,10 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         for expert_idx in logical_expert_ids:
             up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
             down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
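Both `process_prequanted_weights` hunks add the same normalization: the reload path apparently hands weights over as a list of (name, tensor) pairs, and `dict()` accepts exactly that shape, so the rest of the function can keep calling `state_dict.pop()`. A sketch with hypothetical key names and stand-in arrays:

```python
import numpy as np

w0 = np.zeros((4, 4), dtype=np.int8)   # stand-in quantized weight
s0 = np.ones((4,), dtype=np.float32)   # stand-in scale
state_dict = [
    ("experts.0.up_gate_proj.quant_weight", w0),  # hypothetical key names
    ("experts.0.up_gate_proj.weight_scale", s0),
]
if isinstance(state_dict, list):
    state_dict = dict(state_dict)  # list of (key, value) pairs -> dict
weight = state_dict.pop("experts.0.up_gate_proj.quant_weight")
```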
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/moe/moe.py

@@ -191,6 +191,7 @@ def __init__(
         else:
             self.quant_method = get_moe_method()
         self.redundant_table_manger = redundant_table_manger
+        self.is_rearrange = False
         if self.ep_size > 1:
             self.quant_method.init_ep(self)

@@ -397,7 +398,7 @@ def load_experts_weight(
             )
         ]
         ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
-        if self.redundant_table_manger is not None:
+        if self.redundant_table_manger is not None and is_rearrange is True:
             (
                 ep_rank_to_expert_id_list,
                 expert_id_to_ep_rank_array,
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/models/ernie4_5_moe.py

@@ -214,7 +214,7 @@ def load_state_dict(self, state_dict):
         self.shared_experts.load_state_dict(state_dict)

     def update_state_dict(self, state_dict):
-        self.fused_moe.load_state_dict(state_dict, True)
+        self.experts.load_state_dict(state_dict, True)

     def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
         token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
16 changes: 13 additions & 3 deletions fastdeploy/worker/worker_process.py

@@ -262,13 +262,21 @@ def update_weights_from_tensor(self, mmap_infos):
         """
         update_weights_from_tensor
         """
+        import time
+
+        while True:
+            if self.experts_manager.tensor_infos is None:
+                time.sleep(0.1)
+            else:
+                break
         state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
         rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
-        self.worker.get_model().redundant_table_manger.update_expert_rank_table(
+        self.worker.get_model().ernie.redundant_table_manger.update_expert_rank_table(
             rank_expert_list, logical_to_physical_map, expert_count
         )
-        self.worker.get_model().update_state_dict(state_dicts)
+        # TO BE FIXED
+        self.worker.get_model().ernie.update_state_dict(state_dicts)
         self.experts_manager.tensor_infos = None

     def _broadcast_model_weights_signal(self, src: int, group) -> int:
         model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
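The added loop busy-waits until the experts manager publishes `tensor_infos`. As written it polls forever; a bounded variant would look like this (sketch only, the timeout is not part of the PR):

```python
import time

def wait_for(predicate, poll_s=0.1, timeout_s=None):
    """Poll `predicate` until it returns True, optionally bounding the wait."""
    start = time.monotonic()
    while not predicate():
        if timeout_s is not None and time.monotonic() - start > timeout_s:
            raise TimeoutError("condition not met in time")
        time.sleep(poll_s)

# usage: wait_for(lambda: self.experts_manager.tensor_infos is not None)
```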
@@ -362,7 +370,9 @@ def event_loop_normal(self) -> None:
                         _,
                         _,
                         _,
-                    ) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat)
+                    ) = self.worker.get_model().ernie.redundant_table_manger.get_expert_tokens_stats(
+                        clear_stat=clear_stat
+                    )
                     local_experts_token_stats_array.value[:] = new_stats_array[:]
                 elif local_experts_token_stats_array.value is None:
                     logger.warning("redundant_expert: local_experts_token_stats not init")
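Note the pattern shared by both hunks in this file: attribute access is re-routed through the model's inner `ernie` submodule (`get_model().ernie.redundant_table_manger`, `get_model().ernie.update_state_dict`), matching the `update_state_dict` change in ernie4_5_moe.py above; presumably the redundant-expert plumbing now lives on the inner ERNIE module rather than the causal-LM wrapper.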