""" A model worker that executes the model based on vLLM. See documentations at docs/vllm_integration.md """ import argparse import asyncio import json from typing import List from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse import torch import uvicorn from vllm import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from fastchat.serve.model_worker import ( BaseModelWorker, logger, worker_id, global_counter, model_semaphore, ) from fastchat.utils import get_context_length class VLLMWorker(BaseModelWorker): def __init__( self, controller_addr: str, worker_addr: str, worker_id: str, model_path: str, model_names: List[str], no_register: bool, llm_engine: AsyncLLMEngine, ): super().__init__( controller_addr, worker_addr, worker_id, model_path, model_names ) logger.info( f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." ) self.tokenizer = llm_engine.engine.tokenizer self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) if not no_register: self.init_heart_beat() async def generate_stream(self, params): context = params.pop("prompt") request_id = params.pop("request_id") temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) max_new_tokens = params.get("max_new_tokens", 256) stop_str = params.get("stop", None) echo = params.get("echo", True) # Handle stop_str if stop_str is None: stop_str = [] # TODO(Hao): handle stop token IDs # stop_token_ids = params.get("stop_token_ids", None) or [] # make sampling params in vllm top_p = max(top_p, 1e-5) if temperature <= 1e-5: top_p = 1.0 sampling_params = SamplingParams( n=1, temperature=temperature, top_p=top_p, use_beam_search=False, stop=stop_str, max_tokens=max_new_tokens, ) results_generator = engine.generate(context, sampling_params, request_id) async for request_output in results_generator: prompt = request_output.prompt if echo: text_outputs = [ prompt + output.text for output in request_output.outputs ] else: text_outputs = [output.text for output in request_output.outputs] text_outputs = " ".join(text_outputs) text_outputs = text_outputs.replace("", "") # Note: usage is not supported yet ret = {"text": text_outputs, "error_code": 0, "usage": {}} yield (json.dumps(ret) + "\0").encode() async def generate(self, params): async for x in self.generate_stream(params): pass return json.loads(x[:-1].decode()) app = FastAPI() def release_model_semaphore(): model_semaphore.release() def acquire_model_semaphore(): global model_semaphore, global_counter global_counter += 1 if model_semaphore is None: model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) return model_semaphore.acquire() def create_background_tasks(request_id): async def abort_request() -> None: await engine.abort(request_id) background_tasks = BackgroundTasks() background_tasks.add_task(release_model_semaphore) background_tasks.add_task(abort_request) return background_tasks @app.post("/worker_generate_stream") async def api_generate_stream(request: Request): params = await request.json() await acquire_model_semaphore() request_id = random_uuid() params["request_id"] = request_id generator = worker.generate_stream(params) background_tasks = create_background_tasks(request_id) return StreamingResponse(generator, background=background_tasks) @app.post("/worker_generate") async def api_generate(request: Request): params = await request.json() await acquire_model_semaphore() request_id = random_uuid() params["request_id"] = request_id output = await worker.generate(params) release_model_semaphore() engine.abort(request_id) return JSONResponse(output) @app.post("/worker_get_status") async def api_get_status(request: Request): return worker.get_status() @app.post("/count_token") async def api_count_token(request: Request): params = await request.json() return worker.count_token(params) @app.post("/worker_get_conv_template") async def api_get_conv(request: Request): return worker.get_conv_template() @app.post("/model_details") async def api_model_details(request: Request): return {"context_length": worker.context_len} if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) parser.add_argument("--worker-address", type=str, default="http://localhost:21002") parser.add_argument( "--controller-address", type=str, default="http://localhost:21001" ) parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3") parser.add_argument( "--model-names", type=lambda s: s.split(","), help="Optional display comma separated names", ) parser.add_argument("--limit-model-concurrency", type=int, default=1024) parser.add_argument("--no-register", action="store_true") parser.add_argument("--num-gpus", type=int, default=1) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.model_path: args.model = args.model_path if args.num_gpus > 1: args.tensor_parallel_size = args.num_gpus engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) worker = VLLMWorker( args.controller_address, args.worker_address, worker_id, args.model_path, args.model_names, args.no_register, engine, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info")