6 changes: 3 additions & 3 deletions fastdeploy/demo/offline_demo.py
@@ -17,11 +17,11 @@
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.llm import LLM
 
-model_name_or_path = "./models/llama-7b"
+model_name_or_path = "/workspace/ERNIE-4.5-0.3B-Paddle"
 
 # Hyperparameter settings
-sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
-llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
+sampling_params = SamplingParams(temperature=0.1, max_tokens=30, prompt_logprobs=100)
+llm = LLM(model=model_name_or_path, tensor_parallel_size=1, enable_prefix_caching=False)
 output = llm.generate(prompts="who are you?", use_tqdm=True, sampling_params=sampling_params)
 
 print(output)
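The demo now requests the top 100 logprobs per prompt token and disables prefix caching, presumably so prompt tokens are actually recomputed rather than served from cache. For context, a hypothetical way to inspect the result; the `prompt_logprobs` attribute name on the returned object is an assumption for illustration, not something this diff confirms:

# Hypothetical inspection of prompt logprobs on the demo's output.
# The attribute name `prompt_logprobs` is assumed, not confirmed by this PR.
results = output if isinstance(output, list) else [output]
for result in results:
    prompt_lps = getattr(result, "prompt_logprobs", None)
    if prompt_lps is not None:
        print(f"logprobs returned for {len(prompt_lps)} prompt positions")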
2 changes: 2 additions & 0 deletions fastdeploy/engine/sampling_params.py
@@ -135,6 +135,7 @@ def from_optional(
         reasoning_max_tokens=None,
         min_tokens=1,
         logprobs=None,
+        prompt_logprobs=None,
         bad_words=None,
         guided_decoding=None,
         bad_words_token_ids=None,
@@ -158,6 +159,7 @@ def from_optional(
             reasoning_max_tokens=reasoning_max_tokens,
             min_tokens=min_tokens,
             logprobs=logprobs,
+            prompt_logprobs=prompt_logprobs,
             bad_words=bad_words,
             guided_decoding=guided_decoding,
             bad_words_token_ids=bad_words_token_ids,
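This is the standard plumbing for a new optional knob: accept it in the `from_optional` factory and forward it to the constructor. A self-contained sketch of the pattern with simplified names (not FastDeploy's actual class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:  # simplified stand-in for SamplingParams
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None  # the newly threaded field

    @classmethod
    def from_optional(cls, logprobs=None, prompt_logprobs=None):
        # A keyword accepted here but not forwarded below would silently
        # fall back to its default; that is the failure mode to avoid.
        return cls(logprobs=logprobs, prompt_logprobs=prompt_logprobs)

assert Params.from_optional(prompt_logprobs=100).prompt_logprobs == 100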
7 changes: 4 additions & 3 deletions fastdeploy/inter_communicator/zmq_client.py
@@ -15,6 +15,7 @@
 """
 
 from abc import ABC, abstractmethod
+from multiprocessing.reduction import ForkingPickler
 
 import zmq
 
@@ -37,7 +38,7 @@ def _create_socket(self):
     def _ensure_socket(self):
         """Ensure the socket is created before use."""
         if self.socket is None:
-            self.socket = self._create_socket()
+            self.socket: zmq.Socket = self._create_socket()
 
     @abstractmethod
     def connect(self):
@@ -65,14 +66,14 @@ def send_pyobj(self, data):
         Send a Pickle-serializable object over the socket.
         """
         self._ensure_socket()
-        self.socket.send_pyobj(data)
+        self.socket.send(ForkingPickler.dumps(data), copy=False)
 
     def recv_pyobj(self):
        """
         Receive a Pickle-serializable object from the socket.
         """
         self._ensure_socket()
-        return self.socket.recv_pyobj()
+        return ForkingPickler.loads(self.socket.recv())
 
     @abstractmethod
     def close(self):
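`ForkingPickler.dumps` produces a pickle buffer that pyzmq can hand off without an extra copy (`copy=False`), and `ForkingPickler.loads` is the matching deserializer; zmq_server.py below gets the mirror-image change. A standalone round-trip sketch of the new serialization path, under the assumption that both ends share a context:

# Sketch of the ForkingPickler-over-zmq path this diff switches to.
from multiprocessing.reduction import ForkingPickler
import zmq

ctx = zmq.Context.instance()
server = ctx.socket(zmq.PAIR)
server.bind("inproc://demo")
client = ctx.socket(zmq.PAIR)
client.connect("inproc://demo")

payload = {"req_id": 1, "tokens": [3, 5, 7]}
client.send(ForkingPickler.dumps(payload), copy=False)  # replaces send_pyobj()
assert ForkingPickler.loads(server.recv()) == payload   # replaces recv_pyobj()
client.close()
server.close()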
9 changes: 5 additions & 4 deletions fastdeploy/inter_communicator/zmq_server.py
@@ -19,6 +19,7 @@
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from multiprocessing.reduction import ForkingPickler
 
 import msgpack
 import zmq
@@ -44,7 +45,7 @@ def _create_socket(self):
     def _ensure_socket(self):
         """Ensure the socket is created before use."""
         if self.socket is None:
-            self.socket = self._create_socket()
+            self.socket: zmq.Socket = self._create_socket()
 
     def send_json(self, data):
         """
@@ -65,14 +66,14 @@ def send_pyobj(self, data):
         Send a Pickle-serializable object over the socket.
         """
         self._ensure_socket()
-        self.socket.send_pyobj(data)
+        self.socket.send(ForkingPickler.dumps(data), copy=False)
 
     def recv_pyobj(self):
         """
         Receive a Pickle-serializable object from the socket.
         """
         self._ensure_socket()
-        return self.socket.recv_pyobj()
+        return ForkingPickler.loads(self.socket.recv())
 
     def pack_aggregated_data(self, data):
         """
@@ -111,7 +112,7 @@ def receive_pyobj_once(self, block=False):
             return "zmp socket has closed", None
         try:
             flags = zmq.NOBLOCK if not block else 0
-            return None, self.socket.recv_pyobj(flags=flags)
+            return None, ForkingPickler.loads(self.socket.recv(flags=flags))
         except zmq.Again:
             return None, None
         except Exception as e:
4 changes: 3 additions & 1 deletion fastdeploy/model_executor/layers/sample/sampler.py
@@ -334,7 +334,9 @@ def gather_logprobs(
         else:
             indices = token_ids
             top_logprobs = token_logprobs
-
+        indices = indices.cpu()
+        top_logprobs = top_logprobs.cpu()
+        token_ranks = token_ranks.cpu()
         return LogprobsTensors(indices, top_logprobs, token_ranks)
 
     def forward_cuda(
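For context, a minimal sketch of the kind of quantities a gather_logprobs-style helper computes (top-k candidates, the sampled token's logprob, and its rank), plus the eager hand-off to the CPU that this diff adds; shapes and names are simplified and this is not FastDeploy's actual implementation:

import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 32])             # [batch, vocab]
token_ids = paddle.to_tensor([[3], [7]])   # sampled token per row
logprobs = F.log_softmax(logits, axis=-1)
top_logprobs, indices = paddle.topk(logprobs, k=5, axis=1)
token_lp = paddle.take_along_axis(logprobs, token_ids, axis=1)
# Rank = how many vocab entries score strictly higher than the sampled token.
token_ranks = (logprobs > token_lp).astype("int64").sum(axis=-1)
# Stage results on the host immediately, as the diff does:
indices, top_logprobs, token_ranks = indices.cpu(), top_logprobs.cpu(), token_ranks.cpu()
print(token_ranks.numpy())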
5 changes: 2 additions & 3 deletions fastdeploy/output/token_processor.py
@@ -306,8 +306,6 @@ def process_sampling_results_use_zmq(self):
         """
         if self.speculative_decoding:
             raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding")
-        if self.use_logprobs:
-            raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs")
         rank_id = self.cfg.parallel_config.local_data_parallel_id
         while True:
             try:
@@ -316,7 +314,8 @@
                 ) or (rank_id == 0):
                     receive_datas = self.zmq_server.recv_pyobj()
                     assert isinstance(receive_datas, list)
-                    llm_logger.debug(f"token_processor receive_data {receive_datas}")
+                    if envs.FD_DEBUG:
+                        llm_logger.debug(f"token_processor receive_data {receive_datas}")
 
                     self._reschedule_preempt_task_use_zmq(receive_datas)
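With the `use_logprobs` guard removed, this path no longer rejects logprobs requests. The other change gates the per-batch debug line behind `envs.FD_DEBUG`: an unguarded `logger.debug(f"...")` still formats the f-string over a potentially large payload on every receive, even when the debug level is off. A minimal sketch of the pattern, where the environment lookup is a stand-in for FastDeploy's `envs` module:

import logging
import os

FD_DEBUG = bool(int(os.getenv("FD_DEBUG", "0")))  # stand-in for envs.FD_DEBUG
logger = logging.getLogger("demo")

def on_receive(batch):
    if FD_DEBUG:
        # The f-string is only built when debugging is enabled.
        logger.debug(f"received {batch}")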
2 changes: 1 addition & 1 deletion fastdeploy/worker/gpu_model_runner.py
@@ -2732,7 +2732,7 @@ def _get_prompt_logprobs_list(
 
         logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
         if not logprobs_tensors:
-            logprobs_tensors = LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1)
+            logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)
             self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
         start_idx = request.prefill_start_index
         start_tok = start_idx + 1
6 changes: 3 additions & 3 deletions fastdeploy/worker/output.py
@@ -97,9 +97,9 @@ def tolists(self):
     def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors":
         """Create empty LogprobsTensors on CPU."""
 
-        logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], dtype=paddle.int64).cpu()
-        logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32).cpu()
-        selected_token_ranks = paddle.empty([num_positions], dtype=paddle.int64).cpu()
+        logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], device="cpu", dtype=paddle.int64)
+        logprobs = paddle.empty_like(logprob_token_ids, device="cpu", dtype=paddle.float32)
+        selected_token_ranks = paddle.empty([num_positions], device="cpu", dtype=paddle.int64)
 return LogprobsTensors(
             logprob_token_ids=logprob_token_ids,
             logprobs=logprobs,
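Allocating directly on the host skips the transient device allocation and device-to-host copy that the old `paddle.empty(...).cpu()` pattern implies. A small placement check, written with the portable `.cpu()` form since the `device=` keyword shown in the diff assumes a recent Paddle:

import paddle

# Sketch: verify that logprobs staging buffers land on the host.
ids = paddle.empty([4, 8], dtype=paddle.int64).cpu()
lps = paddle.empty_like(ids, dtype=paddle.float32).cpu()
print(ids.place, lps.place)  # expected: Place(cpu) Place(cpu)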