import asyncio
from typing import (
    Optional,
    List,
    Dict,
    Any,
    AsyncIterator,
    Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
    def __init__(
        self,
        model: AsyncLLMEngine,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        prompt_name: Optional[str] = None,
        context_len: Optional[int] = -1,
    ):
        """
        Initializes the VllmEngine object.

        Args:
            model: The AsyncLLMEngine object.
            tokenizer: The PreTrainedTokenizer object.
            model_name: The name of the model.
            prompt_name: The name of the prompt (optional).
            context_len: The maximum context length (optional, default=-1;
                non-positive values mean no override).
        """
        self.model = model
        self.model_name = model_name.lower()
        self.tokenizer = tokenizer
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

        model_config = asyncio.run(self.model.get_model_config())
        if "qwen" in self.model_name:
            # For Qwen models, prefer an explicitly supplied context length and
            # fall back to 8192 tokens otherwise.
            self.max_model_len = context_len if context_len > 0 else 8192
        else:
            self.max_model_len = model_config.max_model_len

    def apply_chat_template(
        self,
        messages: List[ChatCompletionMessageParam],
        max_tokens: Optional[int] = 256,
        functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[int]]:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.
            max_tokens: The maximum number of tokens in the output (optional, default=256).
            functions: A dictionary or list of dictionaries representing the functions to be applied (optional).
            tools: A list of dictionaries representing the tools to be used (optional).

        Returns:
            Union[str, List[int]]: The rendered prompt string, or a list of input token ids
            for models whose templates are applied at the token level.
        """
        if self.prompt_adapter.function_call_available:
            messages = self.prompt_adapter.postprocess_messages(
                messages, functions, tools,
            )
            if functions or tools:
                logger.debug(f"==== Messages with tools ====\n{messages}")

        if "chatglm3" in self.model_name:
            query, role = messages[-1]["content"], messages[-1]["role"]
            return self.tokenizer.build_chat_input(
                query, history=messages[:-1], role=role
            )["input_ids"][0].tolist()
        elif "qwen" in self.model_name:
            return build_qwen_chat_input(
                self.tokenizer,
                messages,
                self.max_model_len,
                max_tokens,
                functions,
                tools,
            )
        else:
            return self.prompt_adapter.apply_chat_template(messages)
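
    # Illustrative input for apply_chat_template (hypothetical values), using the
    # OpenAI chat format that ChatCompletionMessageParam describes:
    #
    #     [
    #         {"role": "system", "content": "You are a helpful assistant."},
    #         {"role": "user", "content": "What can you do?"},
    #     ]
    #
    # chatglm3 and qwen models return token ids directly; all other models return
    # the prompt string rendered by the prompt adapter.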

    def convert_to_inputs(
        self,
        prompt: Optional[str] = None,
        token_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 256,
    ) -> List[int]:
        """
        Tokenizes the prompt if needed and truncates it so that the input plus
        the requested completion fits within the model's context window.
        """
        max_input_tokens = self.max_model_len - max_tokens
        input_ids = token_ids or self.tokenizer(prompt).input_ids
        # Keep only the most recent tokens when the input is too long.
        return input_ids[-max_input_tokens:]
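
    # Worked example for convert_to_inputs (numbers assumed, not from a real run):
    # with max_model_len = 8192 and max_tokens = 256, max_input_tokens is 7936, so
    # a 10,000-token input keeps only its last 7936 tokens.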

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
        """
        Generates text based on the given parameters and request ID.

        Args:
            params (Dict[str, Any]): A dictionary of parameters for text generation.
            request_id (str): The ID of the request.

        Returns:
            An async iterator that yields the generated outputs.
        """
        max_tokens = params.get("max_tokens", 256)
        prompt_or_messages = params.get("prompt_or_messages")
        if isinstance(prompt_or_messages, list):
            # A list of chat messages still needs to be rendered with the chat template.
            prompt_or_messages = self.apply_chat_template(
                prompt_or_messages,
                max_tokens,
                functions=params.get("functions"),
                tools=params.get("tools"),
            )

        # After templating, the result is either a prompt string or pre-tokenized ids.
        if isinstance(prompt_or_messages, list):
            prompt, token_ids = None, prompt_or_messages
        else:
            prompt, token_ids = prompt_or_messages, None

        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
        try:
            sampling_params = SamplingParams(
                n=params.get("n", 1),
                presence_penalty=params.get("presence_penalty", 0.),
                frequency_penalty=params.get("frequency_penalty", 0.),
                temperature=params.get("temperature", 0.9),
                top_p=params.get("top_p", 0.8),
                stop=params.get("stop", []),
                stop_token_ids=params.get("stop_token_ids", []),
                max_tokens=params.get("max_tokens", 256),
                repetition_penalty=params.get("repetition_penalty", 1.03),
                min_p=params.get("min_p", 0.0),
                best_of=params.get("best_of", 1),
                ignore_eos=params.get("ignore_eos", False),
                use_beam_search=params.get("use_beam_search", False),
                skip_special_tokens=params.get("skip_special_tokens", True),
                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
            )
            result_generator = self.model.generate(
                prompt_or_messages if isinstance(prompt_or_messages, str) else None,
                sampling_params,
                request_id,
                token_ids,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return result_generator

    @property
    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
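

# A minimal usage sketch, not part of the original module: it assumes a Qwen chat
# checkpoint and the AsyncLLMEngine / AsyncEngineArgs API of the vLLM version this
# class targets; adjust the model path, name, and arguments for your setup.
if __name__ == "__main__":
    import uuid

    from transformers import AutoTokenizer
    from vllm.engine.arg_utils import AsyncEngineArgs

    model_path = "Qwen/Qwen-7B-Chat"  # assumed checkpoint; replace as needed
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model=model_path, trust_remote_code=True)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Construct outside any running event loop, since __init__ uses asyncio.run().
    vllm_engine = VllmEngine(
        model=engine,
        tokenizer=tokenizer,
        model_name="qwen-7b-chat",
        context_len=8192,
    )

    async def demo():
        params = {
            "prompt_or_messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 64,
            "temperature": 0.7,
        }
        # generate() returns an async iterator of vLLM RequestOutput objects.
        async for output in vllm_engine.generate(params, request_id=str(uuid.uuid4())):
            print(output.outputs[0].text)

    asyncio.run(demo())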