# NOTE: the original capture began with a Hugging Face Spaces status banner
# ("Spaces: / Sleeping / Sleeping") that is not part of the source code;
# it has been converted to this comment so the module stays valid Python.
import json

import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from api_info import developer_info, model_providers
from typegpt_api import generate, model_mapping, simplified_models
# Application instance. CORS is left wide open (any origin, method, header)
# so browser clients hosted anywhere can reach the API.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): allow_origins=["*"] combined with allow_credentials=True
    # is rejected by browsers per the CORS spec -- confirm intent.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): no route decorator appeared anywhere in the captured source,
# so this handler was never registered; the path below is assumed -- confirm.
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up."""
    return {"status": "OK"}
# NOTE(review): route decorator was missing from the captured source; the
# OpenAI-style path below is assumed -- confirm against the deployed API.
@app.get("/v1/models")
async def get_models():
    """List every model from every provider in OpenAI list format.

    Returns a ``{"object": "list", "data": [...]}`` payload where each entry
    carries the model id, its provider name, and the provider's description,
    all read from ``model_providers``. Any failure while walking that mapping
    is reported as a 500 with the error text.
    """
    try:
        data = [
            {
                "id": model,
                "object": "model",
                "provider": provider,
                "description": info["description"],
            }
            for provider, info in model_providers.items()
            for model in info["models"]
        ]
        return JSONResponse(content={"object": "list", "data": data})
    except Exception as e:
        # Boundary handler: surface missing keys / bad shapes as a 500 payload.
        return JSONResponse(content={"error": str(e)}, status_code=500)
# NOTE(review): route decorator was missing from the captured source; the
# OpenAI-compatible path below is assumed -- confirm against the deployed API.
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat-completions endpoint.

    Parses the JSON body, validates that ``model`` and ``messages`` are
    present (400 otherwise), and forwards all sampling parameters to
    ``generate``. When ``stream`` is truthy, chunks are relayed as
    Server-Sent Events terminated by ``data: [DONE]``; otherwise the full
    completion is returned as a single JSON response. Upstream failures
    surface as a 500 with the error text.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Required parameters.
    model = body.get("model")
    messages = body.get("messages")
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Everything generate() takes besides `stream`, collected once so the
    # streaming and non-streaming branches cannot drift apart.
    gen_kwargs = dict(
        model=model,
        messages=messages,
        temperature=body.get("temperature", 0.7),
        top_p=body.get("top_p", 1.0),
        n=body.get("n", 1),
        stop=body.get("stop"),
        max_tokens=body.get("max_tokens"),
        presence_penalty=body.get("presence_penalty", 0.0),
        frequency_penalty=body.get("frequency_penalty", 0.0),
        logit_bias=body.get("logit_bias"),
        user=body.get("user"),
        timeout=30,  # fixed upstream timeout; adjust as preferred
    )

    try:
        if body.get("stream", False):
            async def generate_stream():
                # NOTE(review): generate() looks synchronous; iterating it here
                # blocks the event loop between chunks -- confirm acceptable.
                for chunk in generate(stream=True, **gen_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
                # SSE terminator expected by OpenAI-style clients.
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )
        response = generate(stream=False, **gen_kwargs)
        return JSONResponse(content=response)
    except Exception as e:
        # Boundary handler: report upstream failures as a 500 JSON payload.
        return JSONResponse(content={"error": str(e)}, status_code=500)
# NOTE(review): route decorator was missing from the captured source; the
# path below is assumed from the handler name -- confirm.
@app.get("/developer_info")
async def get_developer_info():
    """Return the static developer/contact metadata from ``api_info``."""
    return JSONResponse(content=developer_info)
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)