| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import TYPE_CHECKING |
|
|
| import torch |
|
|
| from vllm.sampling_params import SamplingParams |
| from vllm.v1.sample.logits_processor import ( |
| AdapterLogitsProcessor, |
| RequestLogitsProcessor, |
| ) |
|
|
| if TYPE_CHECKING: |
| from vllm.config import VllmConfig |
|
|
| |
# Hard-coded vocabulary ids handed to SingleToolCallEnforcer below.
# NOTE(review): presumably tied to one specific tokenizer vocabulary --
# confirm against the served model's tokenizer before reusing elsewhere.

# Id of the token that is forced as the next token (``<|calls|>``).
CALLS_TOKEN_ID: int = 25
# Id of the token whose appearance triggers the forcing
# (``<|tool_call:end|>``).
TOOL_CALL_END_TOKEN_ID: int = 32
|
|
|
|
class SingleToolCallEnforcer:
    """Request-level logits processor that allows at most one tool call.

    Once the most recently generated token is the tool-call end token
    (``<|tool_call:end|>``), every logit except the one for the calls token
    (``<|calls|>``, described by this module as a stop token) is set to
    ``-inf``. The calls token is then the only one that can be sampled next,
    so a second, parallel tool call cannot be started.
    """

    def __init__(
        self,
        tool_call_end_token_id: int,
        calls_token_id: int,
    ):
        """Remember the token ids that drive the enforcement.

        Args:
            tool_call_end_token_id: Id of the token that closes a tool call;
                seeing it as the last output token triggers the forcing.
            calls_token_id: Id of the token that is forced right afterwards.
        """
        self._tool_call_end_token_id = tool_call_end_token_id
        self._calls_token_id = calls_token_id

    def __call__(
        self,
        output_token_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Restrict the next token to the calls token after a tool call ends.

        Args:
            output_token_ids: Tokens generated so far for this request.
            logits: Next-token logits, indexed directly by token id
                (assumes a 1-D vocabulary-sized tensor -- TODO confirm
                against the caller).

        Returns:
            ``logits`` itself, untouched, when the last generated token is
            not the tool-call end token (or nothing has been generated yet).
            Otherwise a new tensor shaped like ``logits`` that is ``-inf``
            everywhere except at the calls token. The input is never
            modified in place.
        """
        # Nothing to enforce unless the previous token closed a tool call.
        if (
            not output_token_ids
            or output_token_ids[-1] != self._tool_call_end_token_id
        ):
            return logits

        forced = torch.full_like(logits, -float("inf"))
        # Keep the forced token's own logit when it is finite. If it is
        # -inf/NaN/+inf (e.g. already masked upstream), copying it verbatim
        # would leave the row without any finite entry and softmax would
        # produce NaN, so fall back to 0.0. nan_to_num stays on-device and
        # avoids a host sync.
        forced[self._calls_token_id] = torch.nan_to_num(
            logits[self._calls_token_id], nan=0.0, posinf=0.0, neginf=0.0
        )
        return forced
|
|
|
|
class ParallelToolCallLogitsProcessor(AdapterLogitsProcessor):
    """Adapter that limits a request to one tool call when asked to.

    For every request whose sampling params carry
    ``parallel_tool_calls=False``, a ``SingleToolCallEnforcer`` is attached.
    That enforcer makes ``<|calls|>`` (described by this module as a stop
    token) the only admissible token right after ``<|tool_call:end|>``, so
    no further tool call can follow. Requests with any other value get no
    per-request processor.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: torch.device,
        is_pin_memory: bool,
    ):
        """Pass the construction arguments straight to the base adapter."""
        super().__init__(vllm_config, device, is_pin_memory)

    def is_argmax_invariant(self) -> bool:
        """Return False: forcing a specific token can change the argmax."""
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Build the per-request processor for a new request, if needed.

        Args:
            params: Sampling params of the incoming request.

        Returns:
            A ``SingleToolCallEnforcer`` when ``params.parallel_tool_calls``
            is exactly ``False``; ``None`` for every other value.
        """
        # Identity check on purpose: only an explicit False opts in; other
        # falsy values such as None do not.
        single_call_requested = params.parallel_tool_calls is False
        if not single_call_requested:
            return None

        return SingleToolCallEnforcer(
            tool_call_end_token_id=TOOL_CALL_END_TOKEN_ID,
            calls_token_id=CALLS_TOKEN_ID,
        )
|
|
|
|