# Spaces:
# Paused
# Paused
import base64
import json
import logging
import types

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from pydantic import validator

from endpoint_handler import EndpointHandler  # your handler file
# The FastAPI application object for this module.
app = FastAPI()

# Module-level handler slot. It stays None until load_handler() has run;
# predict_endpoint() reads it on every request.
handler = None
async def load_handler() -> None:
    """Create an ``EndpointHandler`` and store it in the module-level ``handler``.

    NOTE(review): nothing visible here registers this coroutine with the
    application (for example as a startup hook). Presumably it is meant to
    run once before the first request -- confirm how it is invoked.
    """
    global handler
    handler = EndpointHandler()
class PredictInput(BaseModel):
    """Inputs for one prediction: an image, a question and a stream flag.

    Both validators raise ``ValueError``; pydantic turns that into a
    validation error on the corresponding field.
    """

    image: str  # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        """Reject a question that is empty or only whitespace."""
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        """Require strict base64 whose decoded size is at most 10 MB.

        The original (still encoded) string is returned unchanged; decoding
        here is only used to validate the content and measure its size.
        """
        try:
            # validate=True rejects characters outside the base64 alphabet
            # instead of silently discarding them.
            decoded = base64.b64decode(v, validate=True)
        except (ValueError, TypeError) as exc:
            # binascii.Error (bad padding / alphabet) is a ValueError subclass.
            raise ValueError("`image` must be valid base64") from exc
        if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
            raise ValueError("Image exceeds 10 MB after decoding")
        return v
class PredictRequest(BaseModel):
    """Request body envelope: the prediction inputs are nested under ``inputs``."""

    inputs: PredictInput
async def root():
    """Return a fixed message indicating that the app is running.

    NOTE(review): no route decorator is visible for this coroutine in this
    view; confirm how (and under which path) it is registered on ``app``.
    """
    return {"message": "FastAPI app is running on Hugging Face"}
async def predict_endpoint(payload: PredictRequest):
    """Pass one request to ``handler.predict`` and wrap its result in a response.

    Args:
        payload: Parsed request body; ``payload.inputs`` supplies ``image``,
            ``question`` and ``stream``.

    Returns:
        * ``JSONResponse`` with status 400 and ``{"error": <message>}`` when
          ``handler.predict`` raises ``ValueError``.
        * ``JSONResponse`` with status 500 and a generic error when it raises
          anything else, or when it returns a ``str`` that is not valid JSON.
        * ``StreamingResponse`` (``text/event-stream``) when it returns a
          generator: one ``data: <json>`` event per chunk, followed by a
          final ``{"end": true}`` event, or an ``{"error": ...}`` event if
          iteration fails part-way through.
        * ``JSONResponse`` with the result otherwise (a ``str`` result is
          parsed as JSON first).

    NOTE(review): no route decorator is visible for this coroutine in this
    view; confirm how (and under which path/method) it is registered.
    """
    print(f"[Request] Received question: {payload.inputs.question}")

    inputs = payload.inputs
    data = {
        "inputs": {
            "image": inputs.image,
            "question": inputs.question,
            "stream": inputs.stream,
        }
    }

    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception:
        # Request boundary: the client still gets a generic message, but the
        # traceback is logged so the failure (including `handler` still being
        # None) is not silently swallowed.
        logging.getLogger(__name__).exception("handler.predict failed")
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # --- Generator result: stream it as server-sent events ---------------
    if isinstance(result, types.GeneratorType):

        def event_stream():
            try:
                for chunk in result:
                    # Every chunk is JSON-encoded and framed as one SSE
                    # "data:" event terminated by a blank line.
                    yield f"data: {json.dumps(chunk)}\n\n"
                # Finally send an end-of-stream marker.
                yield f"data: {json.dumps({'end': True})}\n\n"
            except Exception as e:
                # The response has already started, so the error can only be
                # reported in-band as one more event.
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # --- Non-streaming result: return a single JSON response --------------
    if isinstance(result, str):
        # A str result is expected to hold JSON text; decode it so the
        # response body is the object itself, not a quoted string.
        try:
            parsed = json.loads(result)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return JSONResponse({"error": "Invalid JSON from handler"}, status_code=500)
    else:
        parsed = result  # assumed to be JSON-serializable already (e.g. a dict)

    return JSONResponse(parsed)