llm_fastapi / main.py
sreejith8100's picture
stream
c1aa475
import base64
import json
import logging
import types

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from pydantic import validator

from endpoint_handler import EndpointHandler  # your handler file
# ASGI application object; the route and startup decorators in this file register on it.
app = FastAPI()
# Shared EndpointHandler instance. Stays None until the startup hook assigns it.
handler = None
@app.on_event("startup")
async def load_handler():
    """Create the shared ``EndpointHandler`` when the application starts.

    The instance is stored in the module-level ``handler`` variable, which
    the ``/predict`` route reads on every request.
    """
    # NOTE(review): ``on_event`` is the older FastAPI startup hook; newer
    # releases recommend a ``lifespan`` handler instead -- confirm against
    # the pinned FastAPI version before migrating.
    global handler
    handler = EndpointHandler()
# Largest decoded image payload accepted by ``PredictInput`` (10 MB).
_MAX_IMAGE_BYTES = 10 * 1024 * 1024


class PredictInput(BaseModel):
    """Validated payload of a single prediction request.

    Both validators raise ``ValueError``; pydantic turns that into a
    validation error for the request.
    """

    image: str  # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        """Reject a question that is empty or contains only whitespace.

        Returns the question unchanged (it is not stripped).
        """
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        """Check that ``image`` is strict base64 and decodes to at most 10 MB.

        Returns the original (still encoded) string; the decoded bytes are
        only used for the size check.
        """
        try:
            decoded = base64.b64decode(v, validate=True)
        except ValueError as exc:
            # ``binascii.Error`` (bad alphabet / padding) subclasses
            # ValueError, and non-ASCII text raises a plain ValueError, so
            # this covers every malformed-input case without also hiding
            # unrelated failures such as MemoryError.
            raise ValueError("`image` must be valid base64") from exc
        if len(decoded) > _MAX_IMAGE_BYTES:
            raise ValueError("Image exceeds 10 MB after decoding")
        return v
class PredictRequest(BaseModel):
    """Request body for ``/predict``: the payload nested under ``inputs``."""

    inputs: PredictInput
@app.get("/")
async def root():
    """Return a fixed message confirming that the app is up."""
    status = {"message": "FastAPI app is running on Hugging Face"}
    return status
_logger = logging.getLogger(__name__)


def _sse_event(payload):
    """Serialize *payload* to JSON and wrap it in one SSE ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _stream_events(chunks):
    """Yield one SSE frame per chunk, followed by an end-of-stream marker.

    If iterating *chunks* or serializing a chunk raises, the failure is
    logged and a single ``{"error": ...}`` frame is sent instead of the end
    marker. The HTTP status cannot change at that point because the
    response has already started.
    """
    try:
        for chunk in chunks:
            yield _sse_event(chunk)
        yield _sse_event({"end": True})
    except Exception as exc:
        _logger.exception("Streaming prediction failed")
        yield _sse_event({"error": str(exc)})


@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    """Run the handler on the request and return its answer.

    Returns:
        * ``StreamingResponse`` (``text/event-stream``) when the handler
          returns a generator;
        * ``JSONResponse`` with the handler's result otherwise (a ``str``
          result is parsed as JSON first).

    Error responses (all ``{"error": ...}`` JSON bodies):
        * 503 when the handler has not been created yet;
        * 400 when the handler raises ``ValueError``;
        * 500 for any other handler failure or an unparsable string result.
    """
    print(f"[Request] Received question: {payload.inputs.question}")

    # The handler is assigned by the startup hook; answer explicitly instead
    # of letting an AttributeError on None surface as an opaque 500.
    if handler is None:
        return JSONResponse(
            {"error": "Model handler is not loaded"}, status_code=503
        )

    data = {
        "inputs": {
            "image": payload.inputs.image,
            "question": payload.inputs.question,
            "stream": payload.inputs.stream,
        }
    }

    # NOTE(review): this is a synchronous call inside an ``async def`` route,
    # so it holds the event loop while it runs. Moving it to a worker thread
    # would need confirmation that the handler tolerates concurrent calls.
    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception:
        # Keep the client message generic, but record the traceback so the
        # failure can be diagnosed from the server logs.
        _logger.exception("handler.predict failed")
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # A generator result is streamed to the client as server-sent events.
    if isinstance(result, types.GeneratorType):
        return StreamingResponse(
            _stream_events(result), media_type="text/event-stream"
        )

    # Non-streaming: a str result is expected to hold JSON; anything else is
    # passed to JSONResponse as-is.
    if isinstance(result, str):
        try:
            result = json.loads(result)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            _logger.exception("Handler returned a string that is not valid JSON")
            return JSONResponse(
                {"error": "Invalid JSON from handler"}, status_code=500
            )
    return JSONResponse(result)