from typing import List, Dict, Any, Optional
import requests
import time
import aiohttp
import asyncio
from tqdm.asyncio import tqdm
from .base_model import BaseModel
| from .base_model import BaseModel |
|
|
| class VLLMClient(BaseModel): |
| """ |
| Wrapper class for VLLM OpenAI-Compatible API, supporting aiohttp asynchronous batch requests. |
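
    Example (a minimal usage sketch; assumes a vLLM server is already
    listening at DEFAULT_API_URL and serving the named model):

        client = VLLMClient(model_name="qwen-2.5-omni-7b")
        client.load_model()  # blocks until the server responds
        reply = client.generate(
            {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
        )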
| """ |
| DEFAULT_API_URL = "http://127.0.0.1:8000/v1/chat/completions" |
| DEFAULT_TIMEOUT = 600 |
|
|
    def __init__(
        self,
        model_name: str,
        model_path: str = "",
        max_tokens: int = 8192,
        temperature: float = 0.7,
        repeat_penalty: float = 1.0,
        api_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_concurrent_requests: int = 20
    ) -> None:
        """
        Initialize the vLLM client.

        :param model_name: Value for the "model" field in API requests.
        :param model_path: Unused here; kept for interface compatibility with BaseModel.
        :param max_tokens: Default completion length limit per request.
        :param temperature: Default sampling temperature.
        :param repeat_penalty: Repetition penalty (1.0 disables it); sent as
            vLLM's "repetition_penalty" extension to the OpenAI schema.
        :param api_url: Complete URL of the vLLM API server.
        :param system_prompt: Optional system prompt prepended to every conversation.
        :param max_concurrent_requests: Maximum number of in-flight requests per batch.
        """
        self.model_name = model_name
        self.api_url = api_url if api_url else self.DEFAULT_API_URL
        self.default_max_tokens = max_tokens
        self.default_temperature = temperature
        self.default_repeat_penalty = repeat_penalty
        if system_prompt is not None:
            self.system_message: Optional[Dict[str, str]] = {
                "role": "system",
                "content": system_prompt
            }
        else:
            self.system_message = None
        self.max_concurrent_requests = max_concurrent_requests
|
|
    def load_model(self):
        """No local weights to load; just set headers and wait for the server."""
        self.headers = {"Content-Type": "application/json"}
        self.check_vllm_service(self.api_url)
| |
|
|
    def check_vllm_service(self, api_url: str) -> bool:
        """
        Poll the vLLM service until it responds or the time budget runs out.

        Args:
            api_url: Chat-completions URL of the vLLM service
                (e.g., http://localhost:8000/v1/chat/completions); the health
                check hits the corresponding /v1/models endpoint instead.

        Returns:
            True if the service responds within 20 minutes.

        Raises:
            ValueError: If the service does not respond within 20 minutes.
        """
        # e.g. ".../v1/chat/completions" -> ".../v1/models", the standard
        # OpenAI-compatible model-listing route that vLLM also serves.
        check_url = api_url.replace("v1/chat/completions", "v1/models")

        total_timeout = 1200   # total polling budget in seconds (20 minutes)
        retry_interval = 10    # seconds between polls
        max_retries = total_timeout // retry_interval

        for _ in range(max_retries):
            try:
                response = requests.get(check_url, timeout=5)
                if response.status_code == 200:
                    print("VLLM service started successfully")
                    return True
            except requests.exceptions.RequestException:
                # ConnectionError and Timeout are subclasses; keep polling.
                pass

            print(f"Connecting to VLLM Serving: {check_url}")
            time.sleep(retry_interval)

        raise ValueError("Failed to connect to VLLM service")
|
|
    def _build_conversation(self, query_message: Dict) -> List[Dict]:
        """Build the full conversation list: optional system prompt plus the user message."""
        user_message = {"role": "user", "content": []}
        for content in query_message["content"]:
            if content["type"] == "text":
                user_message["content"].append(content)
            elif content["type"] == "image":
                user_message["content"].append({"type": "image_url", "image_url": {"url": "file://" + content["image"]}})
            elif content["type"] == "audio":
                user_message["content"].append({"type": "audio_url", "audio_url": {"url": "file://" + content["audio"]}})
            elif content["type"] == "video":
                user_message["content"].append({"type": "video_url", "video_url": {"url": "file://" + content["video"]}})
            else:
                raise ValueError(f"Unknown content type: {content['type']}")

        if self.system_message is not None:
            return [self.system_message.copy(), user_message]
        return [user_message]
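
    # For reference (hypothetical paths), a query such as
    #   {"content": [{"type": "text", "text": "Describe this."},
    #                {"type": "image", "image": "/tmp/cat.png"}]}
    # becomes, with no system prompt configured:
    #   [{"role": "user", "content": [
    #       {"type": "text", "text": "Describe this."},
    #       {"type": "image_url", "image_url": {"url": "file:///tmp/cat.png"}}]}]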
|
|
| async def _async_call_api( |
| self, |
| session: aiohttp.ClientSession, |
| user_message: Dict, |
| message_idx: int, |
| timeout: int = DEFAULT_TIMEOUT |
| ) -> tuple[int, Any, Optional[str]]: |
| """ |
| Send single API request asynchronously. |
| |
| Returns (index, model_text, error_message). |
| """ |
| conversation = self._build_conversation(user_message) |
| |
        data = {
            "model": self.model_name,
            "messages": conversation,
            "max_tokens": self.default_max_tokens,
            "temperature": self.default_temperature,
            # "repetition_penalty" is a vLLM sampling extension to the OpenAI schema.
            "repetition_penalty": self.default_repeat_penalty
        }
| |
        try:
            async with session.post(
                self.api_url,
                headers=self.headers,
                json=data,
                timeout=aiohttp.ClientTimeout(total=timeout)
            ) as response:
| |
| if response.status != 200: |
| error_text = await response.text() |
| error_msg = f"🚨 [{message_idx}] API Request failed with status {response.status}. Error: {error_text[:200]}..." |
| print(error_msg) |
| return message_idx, None, error_msg |
| |
| response_json = await response.json() |
| |
| |
| if response_json and response_json.get("choices"): |
| response_text = response_json["choices"][0]["message"]["content"] |
| |
| return message_idx, response_text, None |
| else: |
| error_msg = f"❌ [{message_idx}] API response format error." |
| print(error_msg) |
| return message_idx, None, error_msg |
| |
|
|
| except asyncio.TimeoutError: |
| error_msg = f"⏱️ [{message_idx}] API Request timed out after {timeout} seconds." |
| print(error_msg) |
| return message_idx, None, error_msg |
        except Exception as e:
            error_msg = f"❌ [{message_idx}] An unexpected error occurred: {e}. Data: {str(user_message['content'])[:50]}..."
            print(error_msg)
            return message_idx, None, error_msg
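
    # A successful response is expected to follow the OpenAI chat-completions
    # schema, e.g. {"choices": [{"message": {"content": "..."}}], ...}, which
    # is why only choices[0].message.content is extracted above.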
|
|
| async def generate_batch( |
| self, |
| messages: List[Dict], |
| show_progress: bool = True, |
| progress_desc: str = "Processing" |
| ) -> List[Any]: |
| """ |
| Send batch requests using aiohttp async concurrency with optional progress bar. |
| |
| :param messages: List of user messages. |
| :param show_progress: Whether to show progress bar (default: True). |
| :param progress_desc: Description text for progress bar (default: "Processing"). |
| :return: Result list in original order (containing generated text or None). |
| """ |
| |
| all_results = [] |
| |
| |
| pbar = tqdm(total=len(messages), desc=progress_desc, disable=not show_progress) |
| |
| async with aiohttp.ClientSession() as session: |
| |
| for batch_start in range(0, len(messages), self.max_concurrent_requests): |
| batch_end = min(batch_start + self.max_concurrent_requests, len(messages)) |
| batch_messages = messages[batch_start:batch_end] |
| |
| |
| tasks = [ |
| self._async_call_api(session, msg, idx) |
| for idx, msg in enumerate(batch_messages, start=batch_start) |
| ] |
| |
| |
| batch_results = await asyncio.gather(*tasks) |
| |
| all_results.extend(batch_results) |
| |
| |
                # A disabled tqdm bar ignores update() calls; no guard needed.
                pbar.update(len(batch_results))
| |
| pbar.close() |
| |
| |
| sorted_results = sorted(all_results, key=lambda x: x[0]) |
| |
| |
| final_outputs = [res[1] for res in sorted_results] |
| return final_outputs |
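
    # Design note: dispatching in fixed windows of max_concurrent_requests means
    # one slow request stalls its whole window. An asyncio.Semaphore-based
    # limiter avoids this; a minimal sketch (hypothetical alternative, not part
    # of the original API):
    #
    #     async def generate_batch_sem(self, messages):
    #         sem = asyncio.Semaphore(self.max_concurrent_requests)
    #         async def guarded(session, msg, idx):
    #             async with sem:
    #                 return await self._async_call_api(session, msg, idx)
    #         async with aiohttp.ClientSession() as session:
    #             results = await asyncio.gather(
    #                 *[guarded(session, m, i) for i, m in enumerate(messages)])
    #         return [r[1] for r in sorted(results, key=lambda x: x[0])]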
| |
    def generate(self, message: Dict) -> str:
        """
        Synchronous call for a single request.

        Note: this starts a fresh event loop via asyncio.run() on each call;
        prefer 'generate_batch' (or awaiting '_async_call_api' inside an
        existing event loop) for anything beyond one-off use.
        """
        print("Warning: synchronous 'generate' call; prefer 'generate_batch' or '_async_call_api' where possible.")
| |
| async def run_single(): |
| async with aiohttp.ClientSession() as session: |
| |
| _, text_output, _ = await self._async_call_api(session, message, 0) |
| return text_output |
|
|
        return asyncio.run(run_single())


| if __name__ == '__main__': |
| vllm_client = VLLMClient( |
| model_name="qwen-2.5-omni-7b", |
| api_url="http://127.0.0.1:8000/v1/chat/completions" |
| ) |
|
|
| batch_messages = [ |
| {"role": "user", "content": [{"type": "text", "text": "Why is the sky blue?"}]}, |
| {"role": "user", "content": [{"type": "text", "text": "What is photosynthesis?"}]}, |
| {"role": "user", "content": [{"type": "text", "text": "Please write a Fibonacci sequence function in Python."}]} |
| ] |
|
|
| async def main_batch_run(): |
| print("\n--- Starting async batch requests ---") |
| results = await vllm_client.generate_batch(batch_messages) |
| |
| print("\n--- Batch request results ---") |
| for i, res in enumerate(results): |
| if isinstance(res, str): |
| print(f"Request {i+1}: Success. Result: {res[:50]}...") |
| else: |
| print(f"Request {i+1}: Failed/Timeout.") |
| return results |
|
|
| |
| final_results = asyncio.run(main_batch_run()) |