# test_api / fastapi_app.py
# Uploaded by API-Handler ("Upload 10 files"), commit 501c69f (verified).
import json

import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from api_info import developer_info, model_providers
from typegpt_api import generate, model_mapping, simplified_models
app = FastAPI()

# Fully permissive CORS policy: any origin, method and header is accepted,
# and credentialed requests are allowed.
# NOTE(review): the Starlette docs say credentialed CORS should use explicit
# origins rather than a "*" wildcard — confirm this open policy is intended.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/health_check")
async def health_check():
    """Liveness probe: always answer with ``{"status": "OK"}``."""
    status_payload = {"status": "OK"}
    return status_payload
@app.get("/models")
async def get_models():
    """List every model known to ``model_providers`` in a list envelope.

    Each entry carries the model id, ``"object": "model"``, the provider key
    it was listed under, and that provider's ``description``. Any exception
    raised while reading ``model_providers`` is returned as
    ``{"error": <message>}`` with HTTP status 500.
    """
    try:
        entries = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return JSONResponse(content={"object": "list", "data": entries})
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
@app.post("/chat/completions")
async def chat_completions(request: Request):
    """Serve a chat completion request by delegating to ``generate``.

    The JSON body must be an object containing ``model`` and ``messages``.
    These optional fields are forwarded to ``generate``, with these defaults
    when absent:
    ``temperature`` 0.7, ``top_p`` 1.0, ``n`` 1, ``presence_penalty`` 0.0,
    ``frequency_penalty`` 0.0, and ``None`` for ``stop``, ``max_tokens``,
    ``logit_bias`` and ``user``.

    Returns:
        * HTTP 400 with ``{"error": ...}`` if the body is not valid JSON, is
          not a JSON object, or lacks ``model`` / ``messages``.
        * If ``stream`` is truthy: a ``text/event-stream`` response with one
          ``data: <json chunk>`` event per item yielded by ``generate``.
          A ``data: {"error": ...}`` event is sent if generation fails
          part-way. The stream always ends with ``data: [DONE]``.
        * Otherwise: ``generate``'s result serialized as JSON, or HTTP 500
          with ``{"error": ...}`` if ``generate`` raises.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)
    # Valid JSON that is not an object (a list, string, number, ...) has no
    # .get(); reject it here instead of crashing with an unhandled 500.
    if not isinstance(body, dict):
        return JSONResponse(
            content={"error": "The JSON payload must be an object."},
            status_code=400,
        )

    model = body.get("model")
    messages = body.get("messages")
    stream = body.get("stream", False)

    # Validate required parameters
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Arguments shared by the streaming and non-streaming calls; only the
    # ``stream`` flag differs between them.
    generate_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        # Fixed upstream timeout; the unit is whatever generate() expects
        # (presumably seconds — confirm in typegpt_api).
        "timeout": 30,
    }

    if stream:
        def generate_stream():
            # Deliberately a plain (sync) generator. generate() is synchronous
            # (it is called without await), and StreamingResponse iterates
            # sync iterables in a worker thread, so the event loop is not
            # blocked while chunks are produced.
            try:
                for chunk in generate(stream=True, **generate_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
            except Exception as exc:
                # The status line and headers are already sent, so the
                # failure can only be reported in-band as an SSE event.
                error_event = json.dumps({"error": str(exc)})
                yield f"data: {error_event}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            },
        )

    try:
        # Run the synchronous generate() off the event loop so one slow
        # completion does not stall every other request.
        response = await run_in_threadpool(generate, stream=False, **generate_kwargs)
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.get("/developer_info")
async def get_developer_info():
    """Return the ``developer_info`` object imported from ``api_info`` as JSON."""
    payload = developer_info
    return JSONResponse(content=payload)
if __name__ == "__main__":
    # Serve the app directly when this module is executed as a script;
    # binding to 0.0.0.0 listens on all network interfaces.
    uvicorn.run(app, port=8000, host="0.0.0.0")