# api/core/vllm_engine.py
import asyncio
from typing import (
Optional,
List,
Dict,
Any,
AsyncIterator,
Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
def __init__(
self,
model: AsyncLLMEngine,
tokenizer: PreTrainedTokenizer,
model_name: str,
prompt_name: Optional[str] = None,
context_len: Optional[int] = -1,
):
"""
Initializes the VLLMEngine object.
Args:
model: The AsyncLLMEngine object.
tokenizer: The PreTrainedTokenizer object.
model_name: The name of the model.
prompt_name: The name of the prompt (optional).
context_len: The length of the context (optional, default=-1).
"""
self.model = model
self.model_name = model_name.lower()
self.tokenizer = tokenizer
self.prompt_name = prompt_name.lower() if prompt_name is not None else None
self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
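        # Fetch the vLLM model config eagerly; asyncio.run assumes no event
        # loop is running in this thread yet, i.e. we are still at startup.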
model_config = asyncio.run(self.model.get_model_config())
if "qwen" in self.model_name:
self.max_model_len = context_len if context_len > 0 else 8192
else:
self.max_model_len = model_config.max_model_len

    def apply_chat_template(
self,
messages: List[ChatCompletionMessageParam],
max_tokens: Optional[int] = 256,
functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[str, List[int]]:
"""
Applies a chat template to the given messages and returns the processed output.
Args:
messages: A list of ChatCompletionMessageParam objects representing the chat messages.
max_tokens: The maximum number of tokens in the output (optional, default=256).
functions: A dictionary or list of dictionaries representing the functions to be applied (optional).
tools: A list of dictionaries representing the tools to be used (optional).
Returns:
Union[str, List[int]]: The processed output as a string or a list of integers.
"""
if self.prompt_adapter.function_call_available:
messages = self.prompt_adapter.postprocess_messages(
messages, functions, tools,
)
if functions or tools:
logger.debug(f"==== Messages with tools ====\n{messages}")
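        # ChatGLM3 builds its chat input at the token level via the tokenizer,
        # so return token IDs directly instead of a prompt string.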
if "chatglm3" in self.model_name:
query, role = messages[-1]["content"], messages[-1]["role"]
return self.tokenizer.build_chat_input(
query, history=messages[:-1], role=role
)["input_ids"][0].tolist()
elif "qwen" in self.model_name:
return build_qwen_chat_input(
self.tokenizer,
messages,
self.max_model_len,
max_tokens,
functions,
tools,
)
else:
return self.prompt_adapter.apply_chat_template(messages)

    def convert_to_inputs(
self,
prompt: Optional[str] = None,
token_ids: Optional[List[int]] = None,
max_tokens: Optional[int] = 256,
) -> List[int]:
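        """
        Tokenizes the prompt if token IDs were not supplied, then truncates
        from the left so the input leaves room for max_tokens of generation.
        """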
        # Clamp so max_tokens larger than the context window cannot produce a
        # non-positive budget, which would make the slice keep the wrong tokens.
        max_input_tokens = max(self.max_model_len - max_tokens, 1)
input_ids = token_ids or self.tokenizer(prompt).input_ids
return input_ids[-max_input_tokens:]

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
"""
Generates text based on the given parameters and request ID.
Args:
params (Dict[str, Any]): A dictionary of parameters for text generation.
request_id (str): The ID of the request.
Yields:
Any: The generated text.
"""
max_tokens = params.get("max_tokens", 256)
prompt_or_messages = params.get("prompt_or_messages")
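        # A list means chat messages that still need templating; a string is
        # treated as a ready-made prompt.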
if isinstance(prompt_or_messages, list):
prompt_or_messages = self.apply_chat_template(
prompt_or_messages,
max_tokens,
functions=params.get("functions"),
tools=params.get("tools"),
)
if isinstance(prompt_or_messages, list):
prompt, token_ids = None, prompt_or_messages
else:
prompt, token_ids = prompt_or_messages, None
token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
try:
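            # Map request parameters onto vLLM SamplingParams, falling back to
            # these server-side defaults when a field is absent.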
sampling_params = SamplingParams(
n=params.get("n", 1),
presence_penalty=params.get("presence_penalty", 0.),
frequency_penalty=params.get("frequency_penalty", 0.),
temperature=params.get("temperature", 0.9),
top_p=params.get("top_p", 0.8),
stop=params.get("stop", []),
stop_token_ids=params.get("stop_token_ids", []),
max_tokens=params.get("max_tokens", 256),
repetition_penalty=params.get("repetition_penalty", 1.03),
min_p=params.get("min_p", 0.0),
best_of=params.get("best_of", 1),
ignore_eos=params.get("ignore_eos", False),
use_beam_search=params.get("use_beam_search", False),
skip_special_tokens=params.get("skip_special_tokens", True),
spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
)
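            # Positional arguments follow the older vLLM AsyncLLMEngine.generate
            # signature: (prompt, sampling_params, request_id, prompt_token_ids).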
result_generator = self.model.generate(
prompt_or_messages if isinstance(prompt_or_messages, str) else None,
sampling_params,
request_id,
token_ids,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
return result_generator

    @property
def stop(self):
"""
Gets the stop property of the prompt adapter.
Returns:
The stop property of the prompt adapter, or None if it does not exist.
"""
return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
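

# A minimal usage sketch, kept as a comment. The model path, engine arguments,
# and request shape below are illustrative assumptions, not part of this module:
#
#     from transformers import AutoTokenizer
#     from vllm.engine.arg_utils import AsyncEngineArgs
#
#     args = AsyncEngineArgs(model="Qwen/Qwen-7B-Chat", trust_remote_code=True)
#     engine = VllmEngine(
#         model=AsyncLLMEngine.from_engine_args(args),
#         tokenizer=AutoTokenizer.from_pretrained(
#             "Qwen/Qwen-7B-Chat", trust_remote_code=True
#         ),
#         model_name="qwen",
#     )
#
#     async def run():
#         params = {
#             "prompt_or_messages": [{"role": "user", "content": "Hello!"}],
#             "max_tokens": 128,
#         }
#         # generate() returns an async iterator of vLLM RequestOutput objects.
#         async for output in engine.generate(params, request_id="req-0"):
#             print(output.outputs[0].text)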