| | import os |
| | from enum import Enum |
| | from typing import Optional, Dict, Any, List, Union |
| | from vllm import LLM, SamplingParams |
| | from vllm.outputs import RequestOutput |
| | from transformers import AutoTokenizer |
| |
|
# Default generation budget in tokens.
# NOTE(review): currently unused — send_message defaults max_tokens to 1000;
# confirm which default is intended and wire this constant in if so.
DEFAULT_MAX_TOKENS = 16000
| |
|
class ModelType(Enum):
    """Kind of language model a checkpoint is.

    BASE: plain completion model — prompts are passed as raw text.
    INSTRUCT: chat/instruction-tuned model — prompts go through the
    tokenizer's chat template.
    """

    BASE = "base"
    INSTRUCT = "instruct"
| |
|
class VLLMClient:
    """Thin synchronous wrapper around a local vLLM engine.

    Detects from the model path whether the checkpoint is a base or an
    instruct model and formats prompts accordingly before generation.
    """

    def __init__(self, model_path: str):
        """
        Load the vLLM engine and the matching tokenizer.

        Args:
            model_path: Local path or hub identifier of the model. Also used
                to heuristically detect the model type (base vs instruct).
        """
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        self.llm = LLM(model=model_path)
        # Tokenizer is needed to apply the chat template for instruct models.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _detect_model_type(model_path: str) -> ModelType:
        """
        Heuristically classify the model from its path/name.

        If any keyword below appears (case-insensitively) in the path, the
        model is treated as instruct-tuned; otherwise as a base model.
        """
        model_path_lower = model_path.lower()
        # NOTE(review): 'kista' looks project-specific — confirm it belongs here.
        instruct_keywords = ['instruct', 'chat', 'dialogue', 'conversations', 'kista']
        is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords)
        return ModelType.INSTRUCT if is_instruct else ModelType.BASE

    def _format_base_prompt(self, system: Optional[str], content: str) -> str:
        """
        Format a prompt for base models, prepending the system prompt if any.
        """
        if system:
            # Base models have no chat roles; simply concatenate.
            return f"{system} {content}"
        return content

    def _format_instruct_prompt(self, system: Optional[str], content: str) -> str:
        """
        Format a prompt for instruct models using the model's chat template.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": content})

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def _create_message_payload(self,
                                system: Optional[str],
                                content: str,
                                max_tokens: int,
                                temperature: float) -> Dict[str, Any]:
        """
        Build the formatted prompt and sampling parameters for a request.

        Returns:
            Dict with keys "prompt" (str) and "sampling_params" (SamplingParams).
        """
        if self.model_type == ModelType.BASE:
            formatted_prompt = self._format_base_prompt(system, content)
        else:
            formatted_prompt = self._format_instruct_prompt(system, content)

        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            presence_penalty=0.0,
            frequency_penalty=0.0,
        )

        return {
            "prompt": formatted_prompt,
            "sampling_params": sampling_params
        }

    def send_message(self,
                     content: str,
                     system: Optional[str] = None,
                     max_tokens: int = 1000,
                     temperature: float = 0) -> Dict[str, Any]:
        """
        Send a message to the model and get a response.

        Args:
            content: User message or raw prompt.
            system: System prompt (supported for both base and instruct models).
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            On success: {'status': True, 'result': <generated text>}. If the
            generated text cannot be extracted from the vLLM output, the raw
            outputs object is returned instead (best-effort fallback). On any
            other failure: {'status': False, 'error': <message>}.
        """
        try:
            payload = self._create_message_payload(
                system=system,
                content=content,
                max_tokens=max_tokens,
                temperature=temperature
            )

            outputs = self.llm.generate(
                prompts=[payload["prompt"]],
                sampling_params=payload["sampling_params"]
            )

            try:
                result_text = outputs[0].outputs[0].text.strip()
                return {'status': True, 'result': result_text}
            except (IndexError, AttributeError):
                # Output shape was not the expected RequestOutput structure;
                # fall back to returning the raw outputs rather than failing.
                return {'status': True, 'result': outputs}

        except Exception as e:
            # Top-level boundary: report the failure instead of raising.
            return {'status': False, 'error': str(e)}
| |
|