# Copyright (c) Alibaba, Inc. and its affiliates.
# Code partially sourced from Hugging Face TRL
import atexit
import logging
import time
from typing import List, Optional

import requests
import torch
from dacite import from_dict
from requests import ConnectionError
from torch import nn

from swift.llm import AdapterRequest, InferRequest, Template
from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig
from swift.plugin import Metric
from swift.utils import is_vllm_ascend_available, is_vllm_available

if is_vllm_available():
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator  # noqa

logger = logging.getLogger(__name__)


class VLLMClient:
"""
A client class to interact with a vLLM server.
This class provides methods to infer completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `swift deploy`.

    Args:
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
"""
def __init__(self,
host: str = '0.0.0.0',
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0):
if not is_vllm_available():
raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')
self.session = requests.Session()
self.host = host
self.server_port = server_port
self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
"""
Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
"""
url = f'http://{self.host}:{self.server_port}/health/'
start_time = time.time() # Record the start time
while True:
try:
response = requests.get(url)
except requests.exceptions.RequestException as exc:
# Check if the total timeout duration has passed
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
'seconds. Make sure the server is running by running `swift deploy`.') from exc
else:
if response.status_code == 200:
logger.info('Server is up!')
return None
# Retry logic: wait before trying again
logger.info(f'Server is not up yet. Retrying in {retry_interval} seconds...')
            time.sleep(retry_interval)

    def infer(
self,
infer_requests: List[InferRequest],
request_config: Optional[RequestConfig] = None,
metrics: Optional[List[Metric]] = None,
*,
template: Optional[Template] = None,
use_tqdm: Optional[bool] = None,
adapter_request: Optional[AdapterRequest] = None,
):
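        """
        Send inference requests to the vLLM server and return the parsed responses.

        Args:
            infer_requests (`List[InferRequest]`):
                Requests to run inference on.
            request_config (`RequestConfig`, *optional*):
                Generation/sampling configuration forwarded to the server.
            metrics (`List[Metric]`, *optional*):
                Metrics to collect during inference.
            template (`Template`, *optional*):
                Template used to encode the requests.
            use_tqdm (`bool`, *optional*):
                Whether to show a progress bar during inference.
            adapter_request (`AdapterRequest`, *optional*):
                Adapter (e.g. LoRA) to apply for these requests.

        Returns:
            `List[ChatCompletionResponse]`, one response per request.
        """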
url = f'http://{self.host}:{self.server_port}/infer/'
response = self.session.post(
url,
json={
'infer_requests': infer_requests,
'request_config': request_config,
'metrics': metrics,
'template': template,
'use_tqdm': use_tqdm,
'adapter_request': adapter_request,
},
)
if response.status_code == 200:
return [from_dict(data_class=ChatCompletionResponse, data=resp) for resp in response.json()]
else:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def init_communicator(self):
"""
Initializes the weight update group in a distributed setup for model synchronization.
"""
        # Get the world size from the server (the number of vLLM worker processes)
url = f'http://{self.host}:{self.server_port}/get_world_size/'
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()['world_size']
else:
raise Exception(f'Request failed: {response.status_code}, {response.text}')
world_size = vllm_world_size + 1 # add the client to the world
self.rank = vllm_world_size # the client's rank is the last process
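        # For example, with a tensor-parallel server of size 2, the server owns
        # ranks 0 and 1, the client takes rank 2, and world_size becomes 3.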
# Initialize weight update group
url = f'http://{self.host}:{self.server_port}/init_communicator/'
        # On the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={'host': '0.0.0.0', 'port': self.group_port, 'world_size': world_size})
if response.status_code != 200:
raise Exception(f'Request failed: {response.status_code}, {response.text}')
# Brief delay to allow server initialization. While not strictly required (client socket will retry on
# connection failure), this prevents log warnings like:
# [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
time.sleep(0.1)
# Set up the communication group for weight broadcasting
pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
self.pynccl_comm = PyNcclCommunicator(pg, device=0)
# When the client object is deleted, close the weight update group
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
"""
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
name (`str`):
Name of the layer whose weights are being updated.
weights (`torch.Tensor`):
Tensor containing the updated weights.
"""
dtype, shape = str(weights.dtype), tuple(weights.shape)
url = f'http://{self.host}:{self.server_port}/update_named_param/'
response = self.session.post(url, json={'name': name, 'dtype': dtype, 'shape': shape})
if response.status_code != 200:
raise Exception(f'Request failed: {response.status_code}, {response.text}')
# Broadcast the weights to the other processes
self.pynccl_comm.broadcast(weights, src=self.rank)
        self.pynccl_comm.group.barrier()

    def update_model_params(self, model: nn.Module):
"""
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
model (`nn.Module`):
Model whose parameters (weights/biases) are to be updated.
"""
for name, param in model.named_parameters():
# Update each parameter individually
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
"""
Resets the prefix cache for the model.
"""
url = f'http://{self.host}:{self.server_port}/reset_prefix_cache/'
response = self.session.post(url)
if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def close_communicator(self):
"""
Closes the weight update group and cleans up the communication group.
"""
url = f'http://{self.host}:{self.server_port}/close_communicator/'
try:
response = self.session.post(url)
except ConnectionError:
            # The server might already be down, so we don't need to close the communicator
pass
else:
if response.status_code != 200:
raise Exception(f'Request failed: {response.status_code}, {response.text}')
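# A minimal usage sketch, not part of the original module: it assumes a swift
# vLLM server is already reachable at 127.0.0.1:8000 (e.g. one started with
# `swift deploy`); host, port, prompt, and sampling values are illustrative.
if __name__ == '__main__':
    client = VLLMClient(host='127.0.0.1', server_port=8000, connection_timeout=120.0)
    request = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])
    responses = client.infer([request], request_config=RequestConfig(max_tokens=64))
    print(responses[0].choices[0].message.content)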