import atexit
import logging
import socket
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union
from urllib.parse import urlparse

import requests
import torch
from dacite import from_dict
from requests import ConnectionError
from torch import nn

from swift.llm import AdapterRequest, RolloutInferRequest, Template
from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, GymRolloutResponseChoice,
                                      RequestConfig, RolloutResponseChoice)
from swift.plugin import Metric
from swift.utils import is_vllm_ascend_available, is_vllm_available

if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

if is_vllm_ascend_available():
    # On Ascend NPUs, vllm-ascend ships an HCCL communicator with the same interface,
    # so it is aliased to PyNcclCommunicator and used interchangeably below.
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator

logger = logging.getLogger(__name__)


class VLLMClient:
    """Client for one or more vLLM rollout servers.

    Inference requests are sharded across the servers, and model weights can be
    pushed to all of them through per-server NCCL process groups.
    """

    def __init__(self,
                 base_urls: Optional[List[str]] = None,
                 hosts: List[str] = ['0.0.0.0'],
                 server_ports: List[int] = [8000],
                 group_ports: Union[int, List[int]] = 51216,
                 connection_timeout: float = 240.0):
        if not is_vllm_available():
            raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')

        if base_urls is not None:
            # Explicit URLs take precedence over hosts/server_ports. Resolve each
            # hostname to an IP so init_communicator can reach the server directly.
            self.base_urls = []
            self.hosts = []
            for url in base_urls:
                parsed_url = urlparse(url)
                host = socket.gethostbyname(parsed_url.hostname)
                scheme = parsed_url.scheme or 'http'
                base_url_i = f'{scheme}://{parsed_url.netloc}{parsed_url.path}'
                self.base_urls.append(base_url_i)
                self.hosts.append(host)
        else:
            if len(hosts) != len(server_ports):
                raise ValueError('hosts and server_ports must have the same length when lists are provided')
            self.base_urls = [f'http://{h}:{p}' for h, p in zip(hosts, server_ports)]
            self.hosts = hosts

        self.num_servers = len(self.base_urls)

        # One persistent session per server so HTTP connections are reused.
        self.sessions = [requests.Session() for _ in range(self.num_servers)]

        # Each server needs its own port for the weight-sync process group.
        if isinstance(group_ports, int):
            self.group_ports = [group_ports + i for i in range(self.num_servers)]
        elif isinstance(group_ports, list) and len(group_ports) == self.num_servers:
            self.group_ports = group_ports
        else:
            raise ValueError('group_ports must be an int or a list of length num_servers')

        self.pynccl_comms = []
        self.check_server(connection_timeout)

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """Poll every server's /health/ endpoint in parallel until it answers or total_timeout elapses."""
        server_status = [False] * self.num_servers

        def check_single_server(i):
            start_time = time.time()
            url = f'{self.base_urls[i]}/health/'
            while True:
                try:
                    response = requests.get(url, timeout=retry_interval)
                    if response.status_code == 200:
                        server_status[i] = True
                        return
                except Exception:
                    # The server is not up yet; fall through and retry.
                    pass

                elapsed = time.time() - start_time
                if elapsed >= total_timeout:
                    return

                time.sleep(retry_interval)

        threads = []
        for i in range(self.num_servers):
            t = threading.Thread(target=check_single_server, args=(i, ))
            t.daemon = True
            t.start()
            threads.append(t)

        for t in threads:
            t.join(total_timeout)

        if not all(server_status):
            failed_servers = [self.base_urls[i] for i, status in enumerate(server_status) if not status]
            raise ConnectionError(f'Servers not reachable after {total_timeout}s: {failed_servers}')

    def infer(
        self,
        infer_requests: List[RolloutInferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
        adapter_request: Optional[AdapterRequest] = None,
    ):
        """Shard the requests across all servers, run them in parallel, and return the concatenated results."""
        # The engine type determines which response-choice class the results are parsed into.
        if not hasattr(self, 'use_async_engine') or not hasattr(self, 'use_gym_env'):
            self.get_engine_type()

        # Split the requests into one chunk per server, rounding the chunk size up and
        # padding with empty chunks so every server receives exactly one POST. The
        # max(1, ...) guard keeps the slice step valid when the request list is empty.
        n = len(infer_requests)
        chunk_size = max(1, (n + self.num_servers - 1) // self.num_servers)
        chunks = [infer_requests[i:i + chunk_size] for i in range(0, n, chunk_size)]
        chunks += [[]] * (self.num_servers - len(chunks))

        results = [None] * self.num_servers
        errors = [None] * self.num_servers

        def process_chunk(i, chunk):
            try:
                response = self.sessions[i].post(
                    f'{self.base_urls[i]}/infer/',
                    json={
                        'infer_requests': chunk,
                        'request_config': request_config,
                        'metrics': metrics,
                        'template': template,
                        'use_tqdm': use_tqdm,
                        'adapter_request': adapter_request,
                    },
                )

                if response.status_code != 200:
                    errors[i] = Exception(f'Server {i} failed: {response.status_code}, {response.text}')
                    return

                resp_data = response.json()
                results[i] = self.parse_resp_data(resp_data)
            except Exception as e:
                errors[i] = e

        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
            futures = [executor.submit(process_chunk, i, chunk) for i, chunk in enumerate(chunks)]
            for future in futures:
                future.result()

        all_errors = [e for e in errors if e is not None]
        if all_errors:
            raise RuntimeError(f'Inference failed on one or more servers: {all_errors}')

        # Flatten the per-server results back into a single list, preserving request order.
        return [res for server_results in results for res in server_results]

    def init_communicator(self):
        """Join each server's weight-sync process group as the extra, highest rank."""
        self.pynccl_comms = []

        for i in range(self.num_servers):
            response = self.sessions[i].get(f'{self.base_urls[i]}/get_world_size/')
            if response.status_code != 200:
                raise Exception(f'Server {i} failed: {response.text}')
            vllm_world_size = response.json()['world_size']

            # The client joins as one additional rank after the server's workers.
            world_size = vllm_world_size + 1
            rank = vllm_world_size

            response = self.sessions[i].post(
                f'{self.base_urls[i]}/init_communicator/',
                json={
                    'host': '0.0.0.0',
                    'port': self.group_ports[i],
                    'world_size': world_size
                })
            if response.status_code != 200:
                raise Exception(f'Server {i} init failed: {response.text}')

            # Brief pause so the server can set up its side of the process group first.
            time.sleep(0.1)

            pg = StatelessProcessGroup.create(
                host=self.hosts[i], port=self.group_ports[i], rank=rank, world_size=world_size)
            comm = PyNcclCommunicator(pg, device=0)
            self.pynccl_comms.append(comm)

        # Ensure the server-side communicators are torn down when the client exits.
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
        """Announce a weight update over HTTP, then broadcast the tensor to every server via NCCL."""
        dtype = str(weights.dtype)
        shape = tuple(weights.shape)

        errors = [None] * self.num_servers

        def _update_single_server(i):
            try:
                # Send the metadata first so the server can allocate a matching receive buffer.
                response = self.sessions[i].post(
                    f'{self.base_urls[i]}/update_named_param/',
                    json={
                        'name': name,
                        'dtype': dtype,
                        'shape': shape
                    },
                )
                if response.status_code != 200:
                    raise Exception(f'Server {i} update failed: {response.text}')

                # The client holds the highest rank, so it acts as the broadcast source.
                self.pynccl_comms[i].broadcast(weights, src=self.pynccl_comms[i].rank)
                self.pynccl_comms[i].group.barrier()
            except Exception as e:
                errors[i] = e

        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
            futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)]
            for future in futures:
                future.result()

        all_errors = [e for e in errors if e is not None]
        if all_errors:
            raise RuntimeError(f'Weight update failed on one or more servers: {all_errors}')

    def update_model_params(self, model: nn.Module):
        """Push every parameter of the model to the servers, one named tensor at a time."""
        for name, param in model.named_parameters():
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """Ask every server to drop its prefix cache, e.g. after the weights have changed."""
        errors = [None] * self.num_servers

        def _reset_single_server(i):
            try:
                response = self.sessions[i].post(f'{self.base_urls[i]}/reset_prefix_cache/')
                if response.status_code != 200:
                    raise Exception(f'Server {i} reset failed: {response.text}')
            except Exception as e:
                errors[i] = e

        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
            futures = [executor.submit(_reset_single_server, i) for i in range(self.num_servers)]
            for future in futures:
                future.result()

        all_errors = [e for e in errors if e is not None]
        if all_errors:
            raise RuntimeError(f'reset_prefix_cache failed on one or more servers: {all_errors}')

    def get_engine_type(self):
        """Query the first server for its engine type and cache the flags used when parsing responses."""
        response = self.sessions[0].post(f'{self.base_urls[0]}/get_engine_type/')
        if response.status_code != 200:
            raise Exception(f'Engine type request failed: {response.text}')

        result = response.json()
        self.use_async_engine = result['engine_type'] == 'AsyncLLMEngine'
        self.use_gym_env = result.get('gym_env', False)
        return result['engine_type']

    def close_communicator(self):
        """Best-effort teardown of the server-side communicators; failures are logged, not raised."""
        for i in range(self.num_servers):
            try:
                response = self.sessions[i].post(f'{self.base_urls[i]}/close_communicator/')
                if response.status_code != 200:
                    logger.warning(f'Server {i} close failed: {response.text}')
            except Exception as e:
                logger.warning(f'Error closing server {i} communicator: {str(e)}')

    def parse_resp_data(self, resp_data):
        """Deserialize raw JSON responses into ChatCompletionResponse objects with the engine-appropriate choice class."""
        if self.use_gym_env:
            choice_cls = GymRolloutResponseChoice
        elif self.use_async_engine:
            choice_cls = RolloutResponseChoice
        else:
            choice_cls = ChatCompletionResponseChoice
        result = [
            ChatCompletionResponse(
                choices=[from_dict(data_class=choice_cls, data=c) for c in resp['choices']],
                **{k: v for k, v in resp.items() if k != 'choices'}) for resp in resp_data
        ]

        return result
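

# The block below is a minimal usage sketch, not part of the client itself. It assumes
# a rollout server is already running at http://localhost:8000 and that
# 'Qwen/Qwen2.5-0.5B-Instruct' is the model it serves; both values, and the prompt,
# are placeholders. The call order (init_communicator -> update_model_params ->
# reset_prefix_cache -> infer) simply exercises the methods defined above.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM

    client = VLLMClient(base_urls=['http://localhost:8000'])

    # Join the weight-sync group, push the current weights, and drop cached
    # prefixes that still refer to the old weights.
    client.init_communicator()
    model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct', device_map='cuda')
    client.update_model_params(model)
    client.reset_prefix_cache()

    # Run a rollout; requests are sharded across servers and the results concatenated.
    responses = client.infer(
        infer_requests=[RolloutInferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])],
        request_config=RequestConfig(max_tokens=128),
    )
    print(responses[0].choices[0].message.content)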