import ast
import contextlib
import json
import logging
import os
import sys
import time
from threading import Thread
from traceback import print_exception
from typing import List

import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette import EventSourceResponse
from starlette.responses import PlainTextResponse

from openai_server.log import logger

sys.path.append('openai_server')


class Generation(BaseModel):
    # generation controls accepted by the backend but absent from the stock OpenAI API
    top_k: int | None = 1
    repetition_penalty: float | None = 1
    min_p: float | None = 0.0
    max_time: float | None = 360


class Params(BaseModel):
    user: str | None = Field(default=None, description="Track user")
    model: str | None = Field(default=None, description="Choose model")
    best_of: int | None = Field(default=1, description="Unused")
    frequency_penalty: float | None = 0.0
    max_tokens: int | None = 256
    n: int | None = Field(default=1, description="Unused")
    presence_penalty: float | None = 0.0
    stop: str | List[str] | None = None
    stop_token_ids: List[int] | None = None
    stream: bool | None = False
    temperature: float | None = 0.3
    top_p: float | None = 1.0
    seed: int | None = 1234


class CompletionParams(Params):
    prompt: str | List[str]
    logit_bias: dict | None = None
    logprobs: int | None = None


class TextRequest(Generation, CompletionParams):
    pass
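
# Illustrative example: TextRequest composes both mixins, so a minimal request
# validates as e.g. TextRequest(prompt="Hello", temperature=0.3, top_k=4).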


class TextResponse(BaseModel):
    id: str
    choices: List[dict]
    # default_factory stamps each response at creation time rather than at import time
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "text_completion"
    usage: dict


class ChatParams(Params):
    messages: List[dict]
    tools: list | None = Field(default=None, description="WIP")
    tool_choice: str | None = Field(default=None, description="WIP")


class ChatRequest(Generation, ChatParams):
    pass


class ChatResponse(BaseModel):
    id: str
    choices: List[dict]
    # default_factory stamps each response at creation time rather than at import time
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "chat.completion"
    usage: dict


class ModelInfoResponse(BaseModel):
    model_name: str


class ModelListResponse(BaseModel):
    model_names: List[str]


def verify_api_key(authorization: str = Header(None)) -> None:
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if server_api_key == 'EMPTY':
        # no key configured, so authentication is disabled
        return
    if server_api_key and (authorization is None or authorization != f"Bearer {server_api_key}"):
        raise HTTPException(status_code=401, detail="Unauthorized")
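
# Example (illustrative, assuming the default port below): clients pass the key
# as a bearer token:
#   curl -H "Authorization: Bearer $H2OGPT_OPENAI_API_KEY" http://localhost:5000/v1/models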


app = FastAPI()
check_key = [Depends(verify_api_key)]
# fully permissive CORS: any origin, method, and header may call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class InvalidRequestError(Exception):
    pass


@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    # log the full traceback, then surface any unhandled error to the client as a 400
    print_exception(exc)
    exc2 = InvalidRequestError(str(exc))
    return PlainTextResponse(str(exc2), status_code=400)


@app.options("/", dependencies=check_key)
async def options_route():
    return JSONResponse(content="OK")


@app.post('/v1/completions', response_model=TextResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: TextRequest):
    if request_data.stream:
        async def generator():
            from openai_server.backend import stream_completions
            response = stream_completions(dict(request_data))
            for resp in response:
                # stop generating as soon as the client disconnects
                disconnected = await request.is_disconnected()
                if disconnected:
                    break
                yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())
    else:
        from openai_server.backend import completions
        response = completions(dict(request_data))
        return JSONResponse(response)
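
# Illustrative usage with the official client (the model name is a placeholder):
#   from openai import OpenAI
#   client = OpenAI(base_url='http://localhost:5000/v1', api_key='EMPTY')
#   client.completions.create(model='h2oai/h2ogpt', prompt='Hello', stream=False)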


@app.post('/v1/chat/completions', response_model=ChatResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatRequest):
    if request_data.stream:
        from openai_server.backend import stream_chat_completions

        async def generator():
            response = stream_chat_completions(dict(request_data))
            for resp in response:
                # stop generating as soon as the client disconnects
                disconnected = await request.is_disconnected()
                if disconnected:
                    break
                yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())
    else:
        from openai_server.backend import chat_completions
        response = chat_completions(dict(request_data))
        return JSONResponse(response)
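
# Illustrative usage with the same client as above (model name is a placeholder):
#   client.chat.completions.create(model='h2oai/h2ogpt',
#                                  messages=[{'role': 'user', 'content': 'Hi'}])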


@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
@app.get("/v1/models/{repo}/{model}", dependencies=check_key)
async def handle_models(request: Request):
    # take the model name from the raw path so names containing '/' survive intact
    path = request.url.path
    model_name = path[len('/v1/models/'):]

    from openai_server.backend import gradio_client
    model_dict = ast.literal_eval(gradio_client.predict(api_name='/model_names'))
    base_models = [x['base_model'] for x in model_dict]

    if not model_name:
        response = {
            "object": "list",
            "data": base_models,
        }
    else:
        # list.index raises ValueError on a miss, so check membership first
        if model_name in base_models:
            response = model_dict[base_models.index(model_name)]
        else:
            response = dict(model_name='INVALID')

    return JSONResponse(response)
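
# Response shapes, for reference:
#   GET /v1/models        -> {"object": "list", "data": [<base model names>]}
#   GET /v1/models/<name> -> the matching /model_names entry, else {"model_name": "INVALID"}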


@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info():
    from openai_server.backend import get_model_info
    return JSONResponse(content=get_model_info())


@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_key)
async def handle_list_models():
    from openai_server.backend import get_model_list
    return JSONResponse(content=get_model_list())


def run_server(host='0.0.0.0',
               port=5000,
               ssl_certfile=None,
               ssl_keyfile=None,
               gradio_prefix=None,
               gradio_host=None,
               gradio_port=None,
               h2ogpt_key=None,
               ):
    # hand the Gradio connection details to the backend via the environment
    os.environ['GRADIO_PREFIX'] = gradio_prefix or 'http'
    os.environ['GRADIO_SERVER_HOST'] = gradio_host or 'localhost'
    # environment values must be strings, so coerce a possibly-int port
    os.environ['GRADIO_SERVER_PORT'] = str(gradio_port or '7860')
    os.environ['GRADIO_H2OGPT_H2OGPT_KEY'] = h2ogpt_key or ''

    # fall back to the h2oGPT key when no explicit OpenAI-server key is set
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', os.environ['GRADIO_H2OGPT_H2OGPT_KEY']) or 'EMPTY'
    os.environ['H2OGPT_OPENAI_API_KEY'] = server_api_key

    # environment variables take precedence over the function arguments
    port = int(os.getenv('H2OGPT_OPENAI_PORT', port))
    ssl_certfile = os.getenv('H2OGPT_OPENAI_CERT_PATH', ssl_certfile)
    ssl_keyfile = os.getenv('H2OGPT_OPENAI_KEY_PATH', ssl_keyfile)

    prefix = 'https' if ssl_keyfile and ssl_certfile else 'http'
    logger.info(f'OpenAI API URL: {prefix}://{host}:{port}')
    logger.info(f'OpenAI API key: {server_api_key}')

    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)


def run(wait=True, **kwargs):
    if wait:
        # block the calling thread
        run_server(**kwargs)
    else:
        # serve in the background on a daemon thread
        Thread(target=run_server, kwargs=kwargs, daemon=True).start()
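
# Illustrative launch (module path assumed), blocking until the server exits:
#   from openai_server.server import run
#   run(wait=True)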