
Conversation

ST-XX (Collaborator) commented Nov 5, 2025

Motivation

Optimize the xgrammar integration by introducing asynchronous compilation and native caching. Improve efficiency on CUDA platforms with in-place operations and DLPack tensor interconversion, and remove redundant backend caching logic.

Modifications

  • Refactored the xgrammar integration to use asynchronous compilation and enabled xgrammar's native caching.

  • Removed caching from the backend to avoid duplication.

  • Triggered xgrammar compilation during the prefill stage and joined the compilation result before sampling the first token in decode.

  • For CUDA platforms (see the sketch after this list):

    • Used DLPack as an intermediate format for conversion between paddle.Tensor and torch.Tensor.
    • Leveraged CUDA for in-place acceleration via xgr.apply_token_bitmask_inplace.
    • Removed the previous GPU-to-CPU numpy conversion.
  • Other platforms retain existing logic.
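
For reference, a minimal sketch of the CUDA path (it mirrors the diff in this PR; the wrapper name apply_mask_cuda and its argument names are illustrative, and xgr.apply_token_bitmask_inplace is assumed to accept a CUDA torch.Tensor):

import paddle
import torch
import xgrammar as xgr

def apply_mask_cuda(logits, token_bitmask, indices):
    # View the Paddle logits as a torch.Tensor via DLPack; no copy is made,
    # so the in-place CUDA kernel below also updates the original Paddle buffer.
    t_logits = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(logits))
    xgr.apply_token_bitmask_inplace(
        logits=t_logits,
        bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
        indices=indices,
    )
    # Hand the already-masked tensor back to Paddle, again without copying.
    return paddle.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_logits))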


Usage or Command

import openai

port = "8170"
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="null")

completion = client.chat.completions.create(
    model="null",
    messages=[
        {
            "role": "user",
            "content": "Generate a JSON object containing: names of China's Four Great Inventions, their dynasties of origin, and brief descriptions (each under 50 characters)",
        }
    ],
    response_format={"type": "json_object"}
)
print(completion.choices[0].message.content)

Accuracy Tests

Results are expected to change, since outputs must conform to the structural constraints of guided decoding.

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is being submitted to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot (bot) commented Nov 5, 2025

Thanks for your contribution!

def accept_token(self, token: int) -> None:
    """
    Validate and accept a generated token against the grammar constraints.
    when accept eos_token, is_terminated = True
Collaborator:

Where is the eos_token actually checked here? And how is the case of over-long output handled?

Collaborator (Author):

After the eos token is accepted, the matcher's state becomes is_terminated and it gets reset right below, so subsequent output tokens are no longer format-constrained. With ignore_eos enabled, generation can also continue.

logits = torch.from_numpy(logits.numpy())

logits = logits.float() # cpu
apply_token_bitmask_inplace(
Collaborator:

This operator doesn't seem to have been verified on the other hardware platforms; I'm not sure it can be used there.

Collaborator (Author):

This is a pure CPU operation. The line bitmask=token_bitmask.to(logits.device, non_blocking=True) is a bit misleading; the .to() target here is actually still the CPU.
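
For context, a rough sketch of the CPU fallback being discussed, reconstructed from the hunk above (the wrapper name and the final conversion back to Paddle are assumptions, since the hunk is truncated):

import paddle
import torch
import xgrammar as xgr

def apply_mask_cpu(logits, token_bitmask, indices):
    # Copy the Paddle logits into a CPU torch.Tensor through numpy.
    t_logits = torch.from_numpy(logits.numpy()).float()
    xgr.apply_token_bitmask_inplace(
        logits=t_logits,
        # t_logits.device is the CPU here, so this .to() is effectively a no-op.
        bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
        indices=indices,
    )
    # The numpy round-trip copies, so the masked values must be written back.
    return paddle.to_tensor(t_logits.numpy())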

if current_platform.is_cuda():
    dlpack = paddle.utils.dlpack.to_dlpack(logits)
    t_logits = torch.from_dlpack(dlpack)
    apply_token_bitmask_inplace(
Collaborator:

Doesn't this operator support paddle.Tensor? Why convert to torch.Tensor?

Collaborator (Author):

This still uses the native xgr.apply_token_bitmask_inplace interface, which only supports torch.Tensor.

"""update vocab mask. (cpu-heavy operation)"""
if len(self.logits_processor) == 0:
"""add logits processor to SamplerProcessor"""
assert len(prefill_tokens) == 0
Collaborator:

In the PD disaggregation scenario, wouldn't prefill_tokens be non-empty?

Collaborator (Author):

The PD disaggregation scenario hasn't been verified yet; this assert would fail there.

processor.fill_token_bitmask(self.token_bitmask, idx)

def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
Collaborator:

Is the asynchrony between decode steps not added yet?

Collaborator:

The multi-hardware scenarios haven't been verified yet. If we add support, XPU should come first.

Collaborator (Author):

Since the interface changed, these files had to be updated in sync. The XPU CI passed.

ST-XX requested a review from kevincheng2 November 10, 2025 05:13
yuanlehome previously approved these changes Nov 12, 2025
kevincheng2 previously approved these changes Nov 13, 2025
Copilot AI (Contributor) left a comment:

Pull Request Overview

This PR optimizes the xgrammar integration for guided decoding by introducing asynchronous compilation, native caching, and CUDA-specific optimizations using DLPack for efficient tensor conversion.

Key Changes:

  • Implemented async compilation for xgrammar with native caching in the compiler, removing redundant backend-level caching
  • Added CUDA-optimized path using DLPack for zero-copy tensor conversion between Paddle and PyTorch
  • Refactored skip_idx_list to prefill_done_idxs across model runners for clearer semantics when handling chunked prefill with guided decoding

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 26 comments.

Summary per file:

  • fastdeploy/model_executor/layers/sample/sampler.py: Refactored GuidedDecoding class to handle async Future-based processors, changed from dict to list-based storage, and updated token mask application logic
  • fastdeploy/model_executor/guided_decoding/xgrammar_backend.py: Modified XGrammarProcessor to use is_terminated property instead of method, updated accept_token to return boolean and reset on failure, added CUDA-optimized apply_token_mask using DLPack, and configured compiler with native caching
  • fastdeploy/model_executor/guided_decoding/base_guided_decoding.py: Removed backend-level caching logic, changed get_logits_processor return type to Future, increased ThreadPoolExecutor workers, and switched to fast tokenizer
  • fastdeploy/worker/gpu_model_runner.py: Renamed _get_skip_idx to _get_p_done_idxs_gd with updated logic for prefill completion detection, removed _add_cache method, and updated sampler pre/post-process calls
  • fastdeploy/worker/metax_model_runner.py: Applied same refactoring as gpu_model_runner: renamed skip index method, removed caching logic, and updated sampler integration
  • fastdeploy/worker/hpu_model_runner.py: Removed _get_skip_idx and _add_cache methods, simplified guided decoding integration
  • fastdeploy/worker/gcu_model_runner.py: Applied same refactoring as other model runners with renamed methods and removed caching
  • fastdeploy/scheduler/splitwise_scheduler.py: Added Redis version check to ensure 6.2+ compatibility for batch RPOP operations (see the sketch after this list)
  • tests/layers/test_sampler.py: Updated tests to pass FDConfig to Sampler constructor and added configuration building utilities
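
Regarding the splitwise_scheduler.py entry, a minimal sketch of such a version gate with redis-py (function and variable names are illustrative; only the 6.2+ requirement for RPOP with a count argument comes from the PR):

import redis

def check_redis_version(client: redis.Redis, minimum=(6, 2)) -> None:
    # RPOP with a count argument requires Redis >= 6.2.
    version = client.info("server")["redis_version"]
    major_minor = tuple(int(part) for part in version.split(".")[:2])
    if major_minor < minimum:
        raise RuntimeError(f"Redis {version} found; >= 6.2 is required for batch RPOP")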

# skip, join at apply_token_mask
if isinstance(processor, Future):
    continue
if processor.is_terminated:
Copilot AI commented Nov 13, 2025:

Inconsistent property vs method call: On line 149, is_terminated is accessed as a property (processor.is_terminated), but in the original code it was called as a method (is_terminated()). However, on line 227, it's also accessed as a property. In the XGrammarProcessor class (xgrammar_backend.py line 91), is_terminated is defined as a boolean attribute, not a method. This is a breaking change from the original API where it was a method. Ensure all usages are updated correctly and this change is intentional.

Comment on lines 124 to +141
def accept_token(self, token: int) -> None:
    """
    Validate and accept a generated token against the grammar constraints.
    when accept eos_token, is_terminated = True
    Args:
        token (int): The token ID to validate
    Raises:
        AssertionError: If token is not allowed by the grammar
    """
    assert self.matcher.accept_token(token), f"Failed to accept token {token}"

def is_terminated(self) -> bool:
    """
    Check if the grammar matching process has terminated.
    Returns:
        bool: True if matching has terminated, False otherwise
    """
    return self.matcher.is_terminated()
    if self.is_terminated or self.matcher.is_terminated():
        self.is_terminated = True
        return False
    if not self.matcher.accept_token(token):
        self.matcher.reset()
        return False
    if self.matcher.is_terminated():
        self.is_terminated = True
        return True
Copilot AI commented Nov 13, 2025:

The accept_token method's behavior has changed significantly. In the original implementation, it raised an assertion error if the token couldn't be accepted. Now it returns False and resets the matcher (line 137). This is a breaking API change that could silently fail where it previously would have raised an error. The docstring on line 126-127 still mentions "when accept eos_token, is_terminated = True" but doesn't document the return value or the new reset behavior.
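
One possible shape for an updated docstring, offered purely as a suggestion rather than as part of the PR:

def accept_token(self, token: int) -> bool:
    """
    Validate and accept a generated token against the grammar constraints.

    Accepting the eos_token sets is_terminated to True. If the token is
    rejected by the grammar, the matcher is reset and False is returned.

    Args:
        token (int): The token ID to validate.

    Returns:
        bool: True if the token was accepted, False otherwise.
    """
    ...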

Comment on lines +187 to +192
self.grammar_compiler = GrammarCompiler(
    tokenizer_info=tokenizer_info,
    max_threads=8,
    cache_enabled=True,
    cache_limit_bytes=4 * 1024 * 1024,
) # TODO cfg
Copilot AI commented Nov 13, 2025:

Hardcoded configuration values that should be configurable. On lines 189-192, max_threads=8, cache_enabled=True, and cache_limit_bytes=4 * 1024 * 1024 are hardcoded with a TODO comment suggesting they should be configurable. These parameters significantly affect performance and should be exposed through the configuration system.

Suggested change
self.grammar_compiler = GrammarCompiler(
    tokenizer_info=tokenizer_info,
    max_threads=8,
    cache_enabled=True,
    cache_limit_bytes=4 * 1024 * 1024,
) # TODO cfg
# Read configuration values, fallback to defaults if not set
xgrammar_cfg = getattr(fd_config, "xgrammar_config", {})
max_threads = getattr(xgrammar_cfg, "max_threads", 8)
cache_enabled = getattr(xgrammar_cfg, "cache_enabled", True)
cache_limit_bytes = getattr(xgrammar_cfg, "cache_limit_bytes", 4 * 1024 * 1024)
self.grammar_compiler = GrammarCompiler(
    tokenizer_info=tokenizer_info,
    max_threads=max_threads,
    cache_enabled=cache_enabled,
    cache_limit_bytes=cache_limit_bytes,
)

self.cache = {}
self.fd_config = fd_config
self.executor = ThreadPoolExecutor()
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
Copilot AI commented Nov 13, 2025:

The logic for determining max_workers using (multiprocessing.cpu_count() + 1) // 2 seems arbitrary without documentation. Consider adding a comment explaining why half the CPU count plus one is chosen, or make this configurable through the FDConfig.

Suggested change
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
# Determine max_workers for ThreadPoolExecutor.
# Default is half the CPU count plus one, to balance concurrency and avoid oversubscription.
# This can be overridden by setting 'max_workers' in FDConfig.
max_workers = getattr(self.fd_config, "max_workers", None)
if max_workers is None:
    max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)

Comment on lines +463 to +471
dlpack = paddle.utils.dlpack.to_dlpack(logits)
t_logits = torch.from_dlpack(dlpack)
apply_token_bitmask_inplace(
    logits=t_logits,
    bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
    indices=indices,
)
dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
return paddle.utils.dlpack.from_dlpack(dlpack2)
Copilot AI commented Nov 13, 2025:

Potential memory safety issue with DLPack conversion. The code converts logits from Paddle to DLPack (line 463), then to PyTorch (line 464), performs in-place modification (line 465-469), converts back to DLPack (line 470), and finally back to Paddle (line 471). However, there's no guarantee that the Paddle tensor remains valid after the first DLPack conversion, and modifying through PyTorch could lead to undefined behavior if Paddle has already deallocated or moved the underlying memory. Consider adding documentation about the lifetime guarantees or testing this carefully with different tensor configurations.
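
One way to build confidence here is a small aliasing check: write through the torch view and verify the Paddle tensor observes it (a test sketch assuming CUDA builds of both frameworks; not part of the PR):

import paddle
import torch

def test_dlpack_view_shares_storage():
    logits = paddle.rand([2, 8], dtype="float32")  # lives on the GPU in a CUDA build
    t_logits = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(logits))
    t_logits.fill_(float("-inf"))  # in-place write through the torch view
    # If the buffers alias, the original Paddle tensor sees the same values.
    assert bool(paddle.all(paddle.isinf(logits)))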

try:
    tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
    self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
    llm_logger.info(f"xgrammar_backend.py tokenzer_info={tokenizer_info.dump_metadata()}")
Copilot AI commented Nov 13, 2025:

The tokenzer_info in the log message contains a spelling error. It should be tokenizer_info.

Suggested change
llm_logger.info(f"xgrammar_backend.py tokenzer_info={tokenizer_info.dump_metadata()}")
llm_logger.info(f"xgrammar_backend.py tokenizer_info={tokenizer_info.dump_metadata()}")

Comment on lines +94 to +100
tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
os.makedirs(tmp_dir, exist_ok=True)
with open(f"./{tmp_dir}/config.json", "w") as f:
    json.dump(config_dict, f)
model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
print("model_name_or_path", model_name_or_path)
return model_name_or_path
Copilot AI commented Nov 13, 2025:

The test creates a temporary directory with a hardcoded name including paddle.distributed.get_rank(), which will create directories like ./tmpefefGET_RANK_VALUE. This could cause issues if multiple tests run in parallel or if cleanup fails. Consider using Python's tempfile.mkdtemp() for safer temporary directory creation and proper cleanup in a teardown method.
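
A possible shape for that, sketched with unittest and tempfile (class and method names are illustrative):

import json
import os
import shutil
import tempfile
import unittest

class TestSamplerConfig(unittest.TestCase):
    def setUp(self):
        # Unique, collision-free directory per test; removed in tearDown.
        self.tmp_dir = tempfile.mkdtemp(prefix="fd_sampler_cfg_")

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def write_config(self, config_dict) -> str:
        with open(os.path.join(self.tmp_dir, "config.json"), "w") as f:
            json.dump(config_dict, f)
        return self.tmp_dir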

fd_config = get_fd_config(batch_size)
fd_config.model_config.logprobs_mode = logprobs_mode
sampler = Sampler(logprobs_mode=logprobs_mode, fd_config=fd_config)
assert sampler.logprobs_mode == logprobs_mode
Copilot AI commented Nov 13, 2025:

[nitpick] The assertion on line 209 assert sampler.logprobs_mode == logprobs_mode is testing the exact same value that was just set on line 208. This test is redundant since it's just verifying assignment works, which is a Python language feature. Consider removing this assertion or testing something more meaningful about the sampler's behavior with different logprobs modes.

Suggested change
assert sampler.logprobs_mode == logprobs_mode

Comment on lines +220 to +225
if self.reasoning_parser is not None:
    if not self.logits_processors[idx].enable_reasoning:
        if not self.logits_processors[idx].reasoning_ended:
            reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
            self.logits_processors[idx].reasoning_ended = reasoning_ended
        return
Copilot AI commented Nov 13, 2025:

The reasoning parser logic appears inverted. Line 221 checks if not self.logits_processors[idx].enable_reasoning: (reasoning is disabled), but then lines 222-225 check if reasoning has ended and call is_reasoning_end. This logic should only execute when reasoning IS enabled (if self.logits_processors[idx].enable_reasoning:), not when it's disabled. The current implementation would skip reasoning end detection for processors that have reasoning enabled.
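
If Copilot's reading is correct, the intended guard would be the positive check, roughly as below; this only illustrates the reported issue, and the author may intend the current behavior:

if self.reasoning_parser is not None:
    # Run reasoning-end detection only when reasoning is enabled for this processor.
    if self.logits_processors[idx].enable_reasoning:
        if not self.logits_processors[idx].reasoning_ended:
            self.logits_processors[idx].reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
        return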

start backup threads
"""

# check redis version first
Copilot AI commented Nov 13, 2025:

The comment "check redis version first" is inconsistent with the capitalization style used elsewhere in the code. It should follow the same style as other comments. Consider: "Check Redis version first".

Suggested change
# check redis version first
# Check Redis version first

ST-XX dismissed stale reviews from yuanlehome and kevincheng2 via 037157c November 13, 2025 11:56