Spaces:
Sleeping
Sleeping
File size: 4,024 Bytes
0099d95 f0feabf 498d80c 1ad1813 739823d c53513a 0bddb91 9391fe6 0bddb91 9391fe6 0bddb91 2cd7197 9391fe6 0bddb91 11c5c73 0bddb91 11c5c73 0bddb91 11c5c73 0bddb91 5322bcf 0bddb91 9391fe6 0bddb91 9391fe6 0bddb91 5322bcf 0bddb91 9391fe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import time
import random
import asyncio
import json
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional
from dotenv import load_dotenv
from starlette.responses import StreamingResponse
from openai import OpenAI
from typing import List, Optional, Type
# Load variables from a local .env file (if any) into the process environment
# before the lookups below read them.
load_dotenv()

# Upstream API keys, one slot per environment variable. A slot is None when
# the corresponding variable is not set.
API_KEYS = [
    os.environ.get(var_name)
    for var_name in ("API_GEMINI_1", "API_GEMINI_2", "API_GEMINI_3")
]

# Base URL handed to the OpenAI SDK client; overridable through BASE_URL.
BASE_URL = os.environ.get(
    "BASE_URL",
    "https://generativelanguage.googleapis.com/v1beta/openai/",
)

# Secret that callers of this service must present (see verify_api_key).
EXPECTED_API_KEY = os.environ.get("API_HUGGINGFACE")

# Name of the request header carrying the caller's key. auto_error=False makes
# the dependency yield None for a missing header instead of rejecting the
# request itself, so verify_api_key controls the error response.
API_KEY_NAME = "Authorization"
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)
# The ASGI application object that all routes below are registered on.
app = FastAPI(
    title="OpenAI-SDK-compatible API",
    version="1.0.0",
    description="Un wrapper FastAPI compatibile con le specifiche dell'API OpenAI.",
)

# CORS is fully open: every origin, method and header is allowed, and
# credentials are enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def verify_api_key(api_key: str = Depends(api_key_header)):
    """Check the caller's key header against the configured secret.

    The header value must be exactly ``Bearer <EXPECTED_API_KEY>``.

    Args:
        api_key: Raw header value injected by the ``api_key_header``
            dependency; ``None`` when the header is absent.

    Returns:
        The header value, unchanged, when it is valid.

    Raises:
        HTTPException: 403 when the header is missing, when no expected key
            is configured on the server, or when the value does not match.
    """
    import secrets  # function-scope import keeps this block self-contained

    if not api_key:
        raise HTTPException(status_code=403, detail="API key mancante")

    # Fail closed when the expected key is not configured. Otherwise the
    # f-string below would produce the literal "Bearer None", which any
    # client could send to get through.
    if not EXPECTED_API_KEY:
        raise HTTPException(status_code=403, detail="API key non valida")

    expected = f"Bearer {EXPECTED_API_KEY}"

    # Constant-time comparison avoids leaking how many leading characters
    # matched. Compare bytes: compare_digest rejects non-ASCII str input.
    supplied_bytes = api_key.encode("utf-8")
    expected_bytes = expected.encode("utf-8")
    if not secrets.compare_digest(supplied_bytes, expected_bytes):
        raise HTTPException(status_code=403, detail="API key non valida")

    return api_key
def get_openai_client():
    """Build an OpenAI SDK client using one randomly chosen configured key.

    Slots in ``API_KEYS`` whose environment variable is unset (``None`` or
    empty) are skipped, so a partially configured deployment never hands an
    empty key to the SDK while a usable one exists.

    Returns:
        An ``OpenAI`` client pointed at ``BASE_URL``. When no key is
        configured at all, ``api_key=None`` is passed through, as before,
        and the SDK decides how to handle it.
    """
    configured_keys = [key for key in API_KEYS if key]
    api_key = random.choice(configured_keys) if configured_keys else None
    return OpenAI(api_key=api_key, base_url=BASE_URL)
def call_api_sync(params: "ChatCompletionRequest", max_retries: int = 5):
    """Send a chat-completion request upstream, retrying on rate limiting.

    A fresh client (and therefore a freshly chosen API key) is obtained for
    every attempt. An error whose text contains ``"429"`` is treated as rate
    limiting: the call sleeps two seconds and tries again, up to
    ``max_retries`` additional attempts.

    Args:
        params: The validated request body. The annotation is a string
            because the class is defined further down in this module.
        max_retries: Maximum number of retries after a rate-limit error.

    Returns:
        Whatever ``client.chat.completions.create`` returns for the request.

    Raises:
        Exception: Any non-rate-limit error is re-raised immediately; a
            rate-limit error is re-raised once the retries are exhausted.
    """
    retries_used = 0
    # Loop instead of recursing: the retry count stays bounded and a long
    # rate-limit episode cannot grow the call stack.
    while True:
        try:
            client = get_openai_client()
            # NOTE(review): this prints the whole request, message contents
            # included -- confirm that is acceptable for the deployment logs.
            print(params)
            return client.chat.completions.create(
                model=params.model,
                messages=[m.model_dump() for m in params.messages],
                max_tokens=params.max_tokens,
                temperature=params.temperature,
                stream=params.stream,
            )
        except Exception as e:
            rate_limited = "429" in str(e)
            if not rate_limited or retries_used >= max_retries:
                # Bare raise preserves the original traceback.
                raise
            retries_used += 1
            time.sleep(2)
async def _resp_async_generator(params: "ChatCompletionRequest"):
    """Yield an upstream streaming completion as ``data: ...`` events.

    Each upstream chunk is serialised to JSON and emitted as
    ``data: <json>`` followed by a blank line; after the last chunk a
    ``data: [DONE]`` event is emitted. Any failure, including client
    construction, is reported in-band as a single ``data: {"error": ...}``
    event, after which the stream ends without ``[DONE]``.

    The SDK client used here is synchronous, so both the initial request and
    every fetch of the next chunk are run in the event loop's default
    executor. This keeps the event loop free to serve other requests while
    waiting on the upstream.

    Args:
        params: The validated request body. The annotation is a string
            because the class is defined further down in this module.

    Yields:
        str: One already-framed event per iteration.
    """
    loop = asyncio.get_running_loop()
    # Sentinel default for next(): a StopIteration must not be raised inside
    # the executor future, so exhaustion is signalled by identity instead.
    exhausted = object()
    try:
        client = get_openai_client()
        stream = await loop.run_in_executor(
            None,
            lambda: client.chat.completions.create(
                model=params.model,
                messages=[m.model_dump() for m in params.messages],
                max_tokens=params.max_tokens,
                temperature=params.temperature,
                stream=True,
            ),
        )
        chunks = iter(stream)
        while True:
            chunk = await loop.run_in_executor(None, next, chunks, exhausted)
            if chunk is exhausted:
                break
            chunk_data = chunk.to_dict() if hasattr(chunk, "to_dict") else chunk
            yield f"data: {json.dumps(chunk_data)}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        error_data = {"error": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"
# ---------------------------------------------------------------------------------------
@app.get("/")
def read_general():
    # Root endpoint: answers with a fixed welcome payload.
    # (Documented with comments rather than a docstring because FastAPI
    # publishes handler docstrings as the OpenAPI operation description.)
    welcome_payload = {"response": "Benvenuto"}
    return welcome_payload
@app.get("/health")
async def health_check():
    # Health probe: always answers with a fixed success payload and performs
    # no checks of its own.
    status_payload = {"message": "success"}
    return status_payload
# One entry of a chat conversation as sent by the caller.
# (Documented with comments rather than a docstring because pydantic publishes
# a model docstring as the schema description.)
class Message(BaseModel):
    # Who authored the message; any string is accepted here.
    role: str
    # The message text.
    content: str
# ---------------------------------- Generazione Testo ---------------------------------------
# Request body accepted by the chat-completions endpoint.
# (Documented with comments rather than a docstring because pydantic publishes
# a model docstring as the schema description.)
class ChatCompletionRequest(BaseModel):
    # Upstream model identifier.
    model: str = "gemini-2.0-flash"
    # Conversation so far; required.
    messages: List[Message]
    # Upper bound on generated tokens, forwarded upstream as-is.
    # NOTE(review): 8196 looks like a typo for 8192 -- confirm against the
    # upstream model's output-token limit.
    max_tokens: Optional[int] = 8196
    # Sampling temperature, forwarded upstream as-is.
    temperature: Optional[float] = 0.8
    # When true, the endpoint answers with a streamed response.
    stream: Optional[bool] = False
@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
async def chat_completions(req: ChatCompletionRequest):
    # Chat-completions endpoint, guarded by verify_api_key.
    #
    # - Rejects a request with an empty message list with HTTP 400.
    # - When req.stream is true, returns a StreamingResponse fed by
    #   _resp_async_generator.
    # - Otherwise forwards the request through call_api_sync and returns the
    #   upstream response object; any failure becomes HTTP 500 carrying the
    #   error text.
    if not req.messages:
        raise HTTPException(status_code=400, detail="Nessun messaggio fornito")

    if req.stream:
        # The generator emits "data: ...\n\n" frames, i.e. server-sent
        # events, so the response is labelled with the SSE media type.
        return StreamingResponse(
            _resp_async_generator(req),
            media_type="text/event-stream",
        )

    try:
        # call_api_sync blocks (network I/O plus sleeps between retries), so
        # run it in the default executor to keep the event loop responsive.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, call_api_sync, req)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e