import contextlib
import logging
import os
import sys
import ast
import json
from threading import Thread
import time
from traceback import print_exception
from typing import List
from pydantic import BaseModel, Field

import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from sse_starlette import EventSourceResponse
from starlette.responses import PlainTextResponse

from openai_server.log import logger

sys.path.append('openai_server')
					
						
class Generation(BaseModel):
    top_k: int | None = 1
    repetition_penalty: float | None = 1
    min_p: float | None = 0.0
    max_time: float | None = 360
					
						
class Params(BaseModel):
    user: str | None = Field(default=None, description="Track user")
    model: str | None = Field(default=None, description="Choose model")
    best_of: int | None = Field(default=1, description="Unused")
    frequency_penalty: float | None = 0.0
    max_tokens: int | None = 256
    n: int | None = Field(default=1, description="Unused")
    presence_penalty: float | None = 0.0
    stop: str | List[str] | None = None
    stop_token_ids: List[int] | None = None
    stream: bool | None = False
    temperature: float | None = 0.3
    top_p: float | None = 1.0
    seed: int | None = 1234
					
						
class CompletionParams(Params):
    prompt: str | List[str]
    logit_bias: dict | None = None
    logprobs: int | None = None
					
						
class TextRequest(Generation, CompletionParams):
    pass
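
# Illustrative sketch (an assumption, not used by the server): because
# TextRequest inherits both Generation and CompletionParams, one flat
# OpenAI-style JSON body carries sampling and completion fields together,
# and the handlers below flatten it via dict(request_data) into the plain
# keyword dict the backend expects, e.g.:
#
#   TextRequest(prompt="Say hello", max_tokens=16, temperature=0.0, top_k=4)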
					
						
class TextResponse(BaseModel):
    id: str
    choices: List[dict]
    # default_factory so the timestamp is taken per response, not once at import
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "text_completion"
    usage: dict
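
# Illustrative non-streaming response shape (field values hypothetical):
#
#   {"id": "cmpl-...", "object": "text_completion", "created": 1700000000,
#    "model": "...", "choices": [{"index": 0, "text": "...",
#    "finish_reason": "stop"}], "usage": {"total_tokens": 20, ...}}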
					
						
class ChatParams(Params):
    messages: List[dict]
    tools: list | None = Field(default=None, description="WIP")
    tool_choice: str | None = Field(default=None, description="WIP")
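
# ChatParams.messages follows the OpenAI chat format; an illustrative
# two-turn payload (contents hypothetical):
#
#   [{"role": "system", "content": "You are a helpful assistant."},
#    {"role": "user", "content": "Say hello"}]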
					
						
class ChatRequest(Generation, ChatParams):
    pass
					
						
class ChatResponse(BaseModel):
    id: str
    choices: List[dict]
    # default_factory so the timestamp is taken per response, not once at import
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "chat.completion"
    usage: dict
					
						
class ModelInfoResponse(BaseModel):
    model_name: str
					
						
class ModelListResponse(BaseModel):
    model_names: List[str]
					
						
def verify_api_key(authorization: str = Header(None)) -> None:
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if server_api_key == 'EMPTY':
        # auth disabled, accept any request
        return
    if server_api_key and (authorization is None or authorization != f"Bearer {server_api_key}"):
        raise HTTPException(status_code=401, detail="Unauthorized")
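
# Client-side sketch of the expected auth header (key value illustrative;
# server address is an assumption for local testing):
#
#   curl -H 'Authorization: Bearer sk-xxxx' http://localhost:5000/v1/models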
					
						
app = FastAPI()
check_key = [Depends(verify_api_key)]
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
					
						
class InvalidRequestError(Exception):
    pass
					
						
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    print_exception(exc)
    exc2 = InvalidRequestError(str(exc))
    return PlainTextResponse(str(exc2), status_code=400)
					
						
						|  | @app.options("/", dependencies=check_key) | 
					
						
						|  | async def options_route(): | 
					
						
						|  | return JSONResponse(content="OK") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
@app.post('/v1/completions', response_model=TextResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: TextRequest):
    if request_data.stream:
        async def generator():
            from openai_server.backend import stream_completions
            response = stream_completions(dict(request_data))
            for resp in response:
                disconnected = await request.is_disconnected()
                if disconnected:
                    break

                yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())
    else:
        from openai_server.backend import completions
        response = completions(dict(request_data))
        return JSONResponse(response)
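
def _example_completions_client():
    # Sketch of calling the endpoint above with the openai>=1.x client, which
    # this API aims to be compatible with; the base_url, key, and model name
    # are assumptions for a local test, not values this module defines.
    from openai import OpenAI
    client = OpenAI(base_url='http://localhost:5000/v1', api_key='EMPTY')
    resp = client.completions.create(model='h2oai/h2ogpt', prompt='Say hello', max_tokens=16)
    return resp.choices[0].text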
					
						
@app.post('/v1/chat/completions', response_model=ChatResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatRequest):
    if request_data.stream:
        from openai_server.backend import stream_chat_completions

        async def generator():
            response = stream_chat_completions(dict(request_data))
            for resp in response:
                disconnected = await request.is_disconnected()
                if disconnected:
                    break

                yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())
    else:
        from openai_server.backend import chat_completions
        response = chat_completions(dict(request_data))
        return JSONResponse(response)
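
def _example_chat_stream_client():
    # Sketch of consuming the SSE stream from the endpoint above via the
    # openai>=1.x client; server address and model name are assumptions for
    # local testing.
    from openai import OpenAI
    client = OpenAI(base_url='http://localhost:5000/v1', api_key='EMPTY')
    stream = client.chat.completions.create(
        model='h2oai/h2ogpt',
        messages=[{'role': 'user', 'content': 'Say hello'}],
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end='', flush=True)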
					
						
						|  | @app.get("/v1/models", dependencies=check_key) | 
					
						
						|  | @app.get("/v1/models/{model}", dependencies=check_key) | 
					
						
						|  | @app.get("/v1/models/{repo}/{model}", dependencies=check_key) | 
					
						
						|  | async def handle_models(request: Request): | 
					
						
						|  | path = request.url.path | 
					
						
						|  | model_name = path[len('/v1/models/'):] | 
					
						
						|  |  | 
					
						
						|  | from openai_server.backend import gradio_client | 
					
						
						|  | model_dict = ast.literal_eval(gradio_client.predict(api_name='/model_names')) | 
					
						
						|  | base_models = [x['base_model'] for x in model_dict] | 
					
						
						|  |  | 
					
						
						|  | if not model_name: | 
					
						
						|  | response = { | 
					
						
						|  | "object": "list", | 
					
						
						|  | "data": base_models, | 
					
						
						|  | } | 
					
						
						|  | else: | 
					
						
						|  | model_index = base_models.index(model_name) | 
					
						
						|  | if model_index >= 0: | 
					
						
						|  | response = model_dict[model_index] | 
					
						
						|  | else: | 
					
						
						|  | response = dict(model_name='INVALID') | 
					
						
						|  |  | 
					
						
						|  | return JSONResponse(response) | 
					
						
						|  |  | 
					
						
						|  |  | 
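
def _example_list_models():
    # Sketch: the bare /v1/models form of the route above returns an
    # OpenAI-style list object, e.g. {"object": "list", "data": [...]};
    # the server address and key are assumptions for local testing.
    import requests
    return requests.get('http://localhost:5000/v1/models',
                        headers={'Authorization': 'Bearer EMPTY'}).json()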
					
						
						|  | @app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key) | 
					
						
						|  | async def handle_model_info(): | 
					
						
						|  | from openai_server.backend import get_model_info | 
					
						
						|  | return JSONResponse(content=get_model_info()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_key) | 
					
						
						|  | async def handle_list_models(): | 
					
						
						|  | from openai_server.backend import get_model_list | 
					
						
						|  | return JSONResponse(content=get_model_list()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def run_server(host='0.0.0.0',
               port=5000,
               ssl_certfile=None,
               ssl_keyfile=None,
               gradio_prefix=None,
               gradio_host=None,
               gradio_port=None,
               h2ogpt_key=None,
               ):
    os.environ['GRADIO_PREFIX'] = gradio_prefix or 'http'
    os.environ['GRADIO_SERVER_HOST'] = gradio_host or 'localhost'
    # environment values must be strings, even if a numeric port was passed in
    os.environ['GRADIO_SERVER_PORT'] = str(gradio_port or '7860')
    os.environ['GRADIO_H2OGPT_H2OGPT_KEY'] = h2ogpt_key or ''

    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', os.environ['GRADIO_H2OGPT_H2OGPT_KEY']) or 'EMPTY'
    os.environ['H2OGPT_OPENAI_API_KEY'] = server_api_key

    port = int(os.getenv('H2OGPT_OPENAI_PORT', port))
    ssl_certfile = os.getenv('H2OGPT_OPENAI_CERT_PATH', ssl_certfile)
    ssl_keyfile = os.getenv('H2OGPT_OPENAI_KEY_PATH', ssl_keyfile)

    prefix = 'https' if ssl_keyfile and ssl_certfile else 'http'
    logger.info(f'OpenAI API URL: {prefix}://{host}:{port}')
    logger.info(f'OpenAI API key: {server_api_key}')

    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
					
						
def run(wait=True, **kwargs):
    if wait:
        run_server(**kwargs)
    else:
        Thread(target=run_server, kwargs=kwargs, daemon=True).start()
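
if __name__ == '__main__':
    # Manual invocation sketch (an assumption; this module may normally be
    # started by a parent process instead): run(wait=True) blocks in the
    # current process, run(wait=False) starts the server in a daemon thread.
    run(wait=True, port=5000)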