| """A server that provides OpenAI-compatible RESTful APIs. It supports: | |
| - Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) | |
| - Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) | |
| - Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) | |
| Usage: | |
| python3 -m fastchat.serve.openai_api_server | |
| """ | |

import asyncio
import argparse
import json
import os
from typing import Generator, Optional, Union, Dict, List, Any

import aiohttp
import fastapi
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx

try:
    from pydantic.v1 import BaseSettings
except ImportError:
    from pydantic import BaseSettings

import shortuuid
import tiktoken
import uvicorn

from fastchat.constants import (
    WORKER_API_TIMEOUT,
    WORKER_API_EMBEDDING_BATCH_SIZE,
    ErrorCode,
)
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.protocol.openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DeltaMessage,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
    APITokenCheckRequest,
    APITokenCheckResponse,
    APITokenCheckResponseItem,
)
from fastchat.utils import build_logger

logger = build_logger("openai_api_server", "openai_api_server.log")

conv_template_map = {}

fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)


async def fetch_remote(url, pload=None, name=None):
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            chunks = []
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return json.dumps(ret)

            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
        output = b"".join(chunks)

    if name is not None:
        res = json.loads(output)
        if name != "":
            res = res[name]
        return res

    return output


class AppSettings(BaseSettings):
    # The address of the model controller.
    controller_address: str = "http://localhost:21001"
    api_keys: Optional[List[str]] = None


app_settings = AppSettings()
app = fastapi.FastAPI()
headers = {"User-Agent": "FastChat API Server"}
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
    if app_settings.api_keys:
        if auth is None or (token := auth.credentials) not in app_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None
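

# A minimal client sketch (assuming the server was started with --api-keys):
# the key travels as a standard Bearer token, which the dependency above
# validates. The key value and URL below are placeholders.
#
#   import requests
#   resp = requests.get(
#       "http://localhost:8000/v1/models",
#       headers={"Authorization": "Bearer YOUR_API_KEY"},
#   )
#   print(resp.json())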


def create_error_response(code: int, message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, code=code).dict(), status_code=400
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    controller_address = app_settings.controller_address
    ret = None

    models = await fetch_remote(controller_address + "/list_models", None, "models")
    if request.model not in models:
        ret = create_error_response(
            ErrorCode.INVALID_MODEL,
            f"Only {'&&'.join(models)} allowed now, your model {request.model}",
        )
    return ret


async def check_length(request, prompt, max_tokens, worker_addr):
    if (
        not isinstance(max_tokens, int) or max_tokens <= 0
    ):  # the model worker does not support max_tokens=None
        max_tokens = 1024 * 1024

    context_len = await fetch_remote(
        worker_addr + "/model_details", {"model": request.model}, "context_length"
    )
    token_num = await fetch_remote(
        worker_addr + "/count_token",
        {"model": request.model, "prompt": prompt},
        "count",
    )
    length = min(max_tokens, context_len - token_num)

    if length <= 0:
        return None, create_error_response(
            ErrorCode.CONTEXT_OVERFLOW,
            f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. Please reduce the length of the messages.",
        )

    return length, None


def check_requests(request) -> Optional[JSONResponse]:
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of range. Either set top_k to -1 or >= 1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None


def process_input(model_name, inp):
    if isinstance(inp, str):
        inp = [inp]
    elif isinstance(inp, list):
        if isinstance(inp[0], int):
            try:
                decoding = tiktoken.model.encoding_for_model(model_name)
            except KeyError:
                logger.warning("Warning: model not found. Using cl100k_base encoding.")
                model = "cl100k_base"
                decoding = tiktoken.get_encoding(model)
            inp = [decoding.decode(inp)]
        elif isinstance(inp[0], list):
            try:
                decoding = tiktoken.model.encoding_for_model(model_name)
            except KeyError:
                logger.warning("Warning: model not found. Using cl100k_base encoding.")
                model = "cl100k_base"
                decoding = tiktoken.get_encoding(model)
            inp = [decoding.decode(text) for text in inp]

    return inp


def create_openai_logprobs(logprob_dict):
    """Create OpenAI-style logprobs."""
    return LogProbs(**logprob_dict) if logprob_dict is not None else None


def _add_to_set(s, new_stop):
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)


async def get_gen_params(
    model_name: str,
    worker_addr: str,
    messages: Union[str, List[Dict[str, str]]],
    *,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    presence_penalty: Optional[float],
    frequency_penalty: Optional[float],
    max_tokens: Optional[int],
    echo: Optional[bool],
    logprobs: Optional[int] = None,
    stop: Optional[Union[str, List[str]]],
    best_of: Optional[int] = None,
    use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
    conv = await get_conv(model_name, worker_addr)
    conv = Conversation(
        name=conv["name"],
        system_template=conv["system_template"],
        system_message=conv["system_message"],
        roles=conv["roles"],
        messages=list(conv["messages"]),  # prevent in-place modification
        offset=conv["offset"],
        sep_style=SeparatorStyle(conv["sep_style"]),
        sep=conv["sep"],
        sep2=conv["sep2"],
        stop_str=conv["stop_str"],
        stop_token_ids=conv["stop_token_ids"],
    )

    if isinstance(messages, str):
        prompt = messages
        images = []
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.set_system_message(message["content"])
            elif msg_role == "user":
                if type(message["content"]) == list:
                    image_list = [
                        item["image_url"]["url"]
                        for item in message["content"]
                        if item["type"] == "image_url"
                    ]
                    text_list = [
                        item["text"]
                        for item in message["content"]
                        if item["type"] == "text"
                    ]

                    text = "\n".join(text_list)
                    conv.append_message(conv.roles[0], (text, image_list))
                else:
                    conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        images = conv.get_images()

    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "logprobs": logprobs,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        "stop_token_ids": conv.stop_token_ids,
    }

    if len(images) > 0:
        gen_params["images"] = images

    if best_of is not None:
        gen_params.update({"best_of": best_of})
    if use_beam_search is not None:
        gen_params.update({"use_beam_search": use_beam_search})

    new_stop = set()
    _add_to_set(stop, new_stop)
    _add_to_set(conv.stop_str, new_stop)

    gen_params["stop"] = list(new_stop)

    logger.debug(f"==== request ====\n{gen_params}")
    return gen_params


async def get_worker_address(model_name: str) -> str:
    """
    Get worker address based on the requested model

    :param model_name: The worker's model name
    :return: Worker address from the controller
    :raises: :class:`ValueError`: No available worker for requested model
    """
    controller_address = app_settings.controller_address
    worker_addr = await fetch_remote(
        controller_address + "/get_worker_address", {"model": model_name}, "address"
    )

    # No available worker
    if worker_addr == "":
        raise ValueError(f"No available worker for {model_name}")
    logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
    return worker_addr


async def get_conv(model_name: str, worker_addr: str):
    conv_template = conv_template_map.get((worker_addr, model_name))
    if conv_template is None:
        conv_template = await fetch_remote(
            worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
        )
        conv_template_map[(worker_addr, model_name)] = conv_template

    return conv_template


@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def show_available_models():
    controller_address = app_settings.controller_address
    ret = await fetch_remote(controller_address + "/refresh_all_workers")
    models = await fetch_remote(controller_address + "/list_models", None, "models")

    models.sort()
    # TODO: return real model permission details
    model_cards = []
    for m in models:
        model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))

    return ModelList(data=model_cards)


@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(request: ChatCompletionRequest):
    """Creates a completion for the chat message"""
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    worker_addr = await get_worker_address(request.model)

    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )

    max_new_tokens, error_check_ret = await check_length(
        request,
        gen_params["prompt"],
        gen_params["max_new_tokens"],
        worker_addr,
    )

    if error_check_ret is not None:
        return error_check_ret

    gen_params["max_new_tokens"] = max_new_tokens

    if request.stream:
        generator = chat_completion_stream_generator(
            request.model, gen_params, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")

    choices = []
    chat_completions = []
    for i in range(request.n):
        content = asyncio.create_task(generate_completion(gen_params, worker_addr))
        chat_completions.append(content)
    try:
        all_tasks = await asyncio.gather(*chat_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if isinstance(content, str):
            content = json.loads(content)

        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
        if "usage" in content:
            task_usage = UsageInfo.parse_obj(content["usage"])
            for usage_key, usage_value in task_usage.dict().items():
                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)


async def chat_completion_stream_generator(
    model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
) -> Generator[str, Any, None]:
    """
    Event stream format:
    https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    """
| id = f"chatcmpl-{shortuuid.random()}" | |
| finish_stream_events = [] | |
| for i in range(n): | |
| # First chunk with role | |
| choice_data = ChatCompletionResponseStreamChoice( | |
| index=i, | |
| delta=DeltaMessage(role="assistant"), | |
| finish_reason=None, | |
| ) | |
| chunk = ChatCompletionStreamResponse( | |
| id=id, choices=[choice_data], model=model_name | |
| ) | |
| yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" | |
| previous_text = "" | |
| async for content in generate_completion_stream(gen_params, worker_addr): | |
| if content["error_code"] != 0: | |
| yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| return | |
| decoded_unicode = content["text"].replace("\ufffd", "") | |
| delta_text = decoded_unicode[len(previous_text) :] | |
| previous_text = ( | |
| decoded_unicode | |
| if len(decoded_unicode) > len(previous_text) | |
| else previous_text | |
| ) | |
| if len(delta_text) == 0: | |
| delta_text = None | |
| choice_data = ChatCompletionResponseStreamChoice( | |
| index=i, | |
| delta=DeltaMessage(content=delta_text), | |
| finish_reason=content.get("finish_reason", None), | |
| ) | |
| chunk = ChatCompletionStreamResponse( | |
| id=id, choices=[choice_data], model=model_name | |
| ) | |
| if delta_text is None: | |
| if content.get("finish_reason", None) is not None: | |
| finish_stream_events.append(chunk) | |
| continue | |
| yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" | |
| # There is not "content" field in the last delta message, so exclude_none to exclude field "content". | |
| for finish_chunk in finish_stream_events: | |
| yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" | |
| yield "data: [DONE]\n\n" | |


@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def create_completion(request: CompletionRequest):
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    request.prompt = process_input(request.model, request.prompt)

    worker_addr = await get_worker_address(request.model)
    for text in request.prompt:
        max_tokens, error_check_ret = await check_length(
            request, text, request.max_tokens, worker_addr
        )
        if error_check_ret is not None:
            return error_check_ret

        # Cap max_tokens to what fits in the context window (also covers max_tokens=None).
        if isinstance(max_tokens, int) and (
            request.max_tokens is None or max_tokens < request.max_tokens
        ):
            request.max_tokens = max_tokens

    if request.stream:
        generator = generate_completion_stream_generator(
            request, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")
    else:
        text_completions = []
        for text in request.prompt:
            gen_params = await get_gen_params(
                request.model,
                worker_addr,
                text,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                frequency_penalty=request.frequency_penalty,
                presence_penalty=request.presence_penalty,
                max_tokens=request.max_tokens,
                logprobs=request.logprobs,
                echo=request.echo,
                stop=request.stop,
                best_of=request.best_of,
                use_beam_search=request.use_beam_search,
            )
            for i in range(request.n):
                content = asyncio.create_task(
                    generate_completion(gen_params, worker_addr)
                )
                text_completions.append(content)

        try:
            all_tasks = await asyncio.gather(*text_completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

        choices = []
        usage = UsageInfo()
        for i, content in enumerate(all_tasks):
            if content["error_code"] != 0:
                return create_error_response(content["error_code"], content["text"])
            choices.append(
                CompletionResponseChoice(
                    index=i,
                    text=content["text"],
                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
                    finish_reason=content.get("finish_reason", "stop"),
                )
            )
            task_usage = UsageInfo.parse_obj(content["usage"])
            for usage_key, usage_value in task_usage.dict().items():
                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        return CompletionResponse(
            model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
        )


async def generate_completion_stream_generator(
    request: CompletionRequest, n: int, worker_addr: str
):
    model_name = request.model
    id = f"cmpl-{shortuuid.random()}"
    finish_stream_events = []
    for text in request.prompt:
        for i in range(n):
            previous_text = ""
            gen_params = await get_gen_params(
                request.model,
                worker_addr,
                text,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                presence_penalty=request.presence_penalty,
                frequency_penalty=request.frequency_penalty,
                max_tokens=request.max_tokens,
                logprobs=request.logprobs,
                echo=request.echo,
                stop=request.stop,
            )
            async for content in generate_completion_stream(gen_params, worker_addr):
                if content["error_code"] != 0:
                    yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = content["text"].replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )
                # todo: index is not apparent
                choice_data = CompletionResponseStreamChoice(
                    index=i,
                    text=delta_text,
                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
                    finish_reason=content.get("finish_reason", None),
                )
                chunk = CompletionStreamResponse(
                    id=id,
                    object="text_completion",
                    choices=[choice_data],
                    model=model_name,
                )
                if len(delta_text) == 0:
                    if content.get("finish_reason", None) is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    # There is no "content" field in the last delta message, so use exclude_unset to drop unset fields.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"


async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
    controller_address = app_settings.controller_address
    async with httpx.AsyncClient() as client:
        delimiter = b"\0"
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            # content = await response.aread()
            buffer = b""
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                # The worker separates JSON messages with NUL bytes; split on the
                # delimiter and yield each complete message as it arrives.
                while (chunk_end := buffer.find(delimiter)) >= 0:
                    chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :]
                    if not chunk:
                        continue
                    yield json.loads(chunk.decode())


async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    return await fetch_remote(worker_addr + "/worker_generate", payload, "")


@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
    """Creates embeddings for the text"""
    if request.model is None:
        request.model = model_name
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    request.input = process_input(request.model, request.input)

    data = []
    token_num = 0
    batch_size = WORKER_API_EMBEDDING_BATCH_SIZE
    batches = [
        request.input[i : min(i + batch_size, len(request.input))]
        for i in range(0, len(request.input), batch_size)
    ]
    for num_batch, batch in enumerate(batches):
        payload = {
            "model": request.model,
            "input": batch,
            "encoding_format": request.encoding_format,
        }
        embedding = await get_embedding(payload)
        if "error_code" in embedding and embedding["error_code"] != 0:
            return create_error_response(embedding["error_code"], embedding["text"])
        data += [
            {
                "object": "embedding",
                "embedding": emb,
                "index": num_batch * batch_size + i,
            }
            for i, emb in enumerate(embedding["embedding"])
        ]
        token_num += embedding["token_num"]

    return EmbeddingsResponse(
        data=data,
        model=request.model,
        usage=UsageInfo(
            prompt_tokens=token_num,
            total_tokens=token_num,
            completion_tokens=None,
        ),
    ).dict(exclude_none=True)


async def get_embedding(payload: Dict[str, Any]):
    controller_address = app_settings.controller_address
    model_name = payload["model"]
    worker_addr = await get_worker_address(model_name)

    embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
    return json.loads(embedding)


### GENERAL API - NOT OPENAI COMPATIBLE ###


@app.post("/api/v1/token_check")
async def count_tokens(request: APITokenCheckRequest):
    """
    Checks the token count for each message in your list.
    This is not part of the OpenAI API spec.
    """
    checkedList = []
    for item in request.prompts:
        worker_addr = await get_worker_address(item.model)

        context_len = await fetch_remote(
            worker_addr + "/model_details",
            {"prompt": item.prompt, "model": item.model},
            "context_length",
        )

        token_num = await fetch_remote(
            worker_addr + "/count_token",
            {"prompt": item.prompt, "model": item.model},
            "count",
        )

        can_fit = True
        if token_num + item.max_tokens > context_len:
            can_fit = False

        checkedList.append(
            APITokenCheckResponseItem(
                fits=can_fit, contextLength=context_len, tokenCount=token_num
            )
        )

    return APITokenCheckResponse(prompts=checkedList)
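

# A sketch of the request/response shape handled above (the model name and
# the numbers are placeholders, not real output):
#
#   POST /api/v1/token_check
#   {"prompts": [{"model": "vicuna-7b-v1.5", "prompt": "Hello", "max_tokens": 64}]}
#
#   -> {"prompts": [{"fits": true, "contextLength": 4096, "tokenCount": 3}]}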


@app.post("/api/v1/chat/completions")
async def create_chat_completion(request: APIChatCompletionRequest):
    """Creates a completion for the chat message"""
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    worker_addr = await get_worker_address(request.model)

    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )

    if request.repetition_penalty is not None:
        gen_params["repetition_penalty"] = request.repetition_penalty

    max_new_tokens, error_check_ret = await check_length(
        request,
        gen_params["prompt"],
        gen_params["max_new_tokens"],
        worker_addr,
    )

    if error_check_ret is not None:
        return error_check_ret

    gen_params["max_new_tokens"] = max_new_tokens

    if request.stream:
        generator = chat_completion_stream_generator(
            request.model, gen_params, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")

    choices = []
    chat_completions = []
    for i in range(request.n):
        content = asyncio.create_task(generate_completion(gen_params, worker_addr))
        chat_completions.append(content)
    try:
        all_tasks = await asyncio.gather(*chat_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
        task_usage = UsageInfo.parse_obj(content["usage"])
        for usage_key, usage_value in task_usage.dict().items():
            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)


### END GENERAL API - NOT OPENAI COMPATIBLE ###


def create_openai_api_server():
    parser = argparse.ArgumentParser(
        description="FastChat ChatGPT-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument(
        "--api-keys",
        type=lambda s: s.split(","),
        help="Optional list of comma-separated API keys",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires the OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    app_settings.controller_address = args.controller_address
    app_settings.api_keys = args.api_keys

    logger.info(f"args: {args}")
    return args


if __name__ == "__main__":
    args = create_openai_api_server()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
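

# Example launch (a sketch; the key and addresses below are placeholders):
#
#   python3 -m fastchat.serve.openai_api_server \
#       --host 0.0.0.0 --port 8000 \
#       --controller-address http://localhost:21001 \
#       --api-keys sk-example-key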