# diffsketcher/api.py
# Committed by jree423 in 45675d3 (verified):
#   "Add: diffsketcher api.py with original implementation"
import asyncio
import io
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Union, Dict, Any

from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel

from handler import EndpointHandler
# Lightweight debug logger used throughout this module.
def debug_log(message):
    """Write ``message`` to stdout prefixed with ``DEBUG:`` and flush immediately.

    Flushing on every call makes each line appear right away even when stdout
    is buffered (for example when it is redirected to a pipe or a file).
    """
    print(f"DEBUG: {message}", flush=True)
debug_log("Starting API initialization")

# The FastAPI application the endpoints below are registered on.
app = FastAPI()

# Directory the handler loads from; overridable via the MODEL_DIR env var.
model_dir = os.getenv("MODEL_DIR", "/code/diffsketcher")
debug_log(f"Using model_dir: {model_dir}")

# Constructed once at import time and shared by every request to `generate`.
handler = EndpointHandler(model_dir)
debug_log("Handler initialized")
class TextRequest(BaseModel):
    """JSON body accepted by ``POST /``.

    The endpoint converts the whole model to a dict and passes that dict to
    ``EndpointHandler``; the meaning of ``inputs`` is defined by the handler.
    """

    # Either a plain string or a mapping with string keys. The expected keys
    # of the mapping form are up to EndpointHandler (not visible here) — TODO confirm.
    inputs: Union[str, Dict[str, Any]]
@app.get("/")
def read_root():
    """Return a fixed JSON banner naming this service."""
    debug_log("Root endpoint called")
    banner = "DiffSketcher Vector Graphics Generation API"
    return {"message": banner}
# Dedicated worker for model inference. Calling the synchronous handler
# directly inside an ``async def`` endpoint blocks the event loop for the whole
# generation, so no other request (not even ``GET /``) can be served meanwhile.
# A single worker keeps handler calls serialized, exactly as they were when the
# handler ran inline on the event loop.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)


def _request_payload(request):
    """Return ``request`` as a plain dict on both pydantic v1 and v2.

    ``BaseModel.dict()`` is deprecated in pydantic v2 in favour of
    ``model_dump()``; v1 models have no ``model_dump`` and use ``dict()``.
    """
    dump = getattr(request, "model_dump", None)
    return dump() if dump is not None else request.dict()


@app.post("/")
async def generate(request: TextRequest):
    """Run the model handler on ``request`` and return its result.

    If the handler returns an object with a ``save`` method (treated as a PIL
    image), it is encoded as PNG and returned as ``image/png``. Any other
    result is returned unchanged for FastAPI to serialize as JSON.

    Raises:
        HTTPException: status 500 carrying the error text when the handler
            call or the image encoding fails.
    """
    try:
        debug_log(f"Generate endpoint called with request: {request}")
        payload = _request_payload(request)
        # Run the (potentially long) model call off the event loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(_INFERENCE_EXECUTOR, handler, payload)
        debug_log("Handler returned result")
        # Duck-typed check: anything with .save() is treated as a PIL Image.
        if hasattr(result, "save"):
            debug_log("Result is a PIL Image, converting to bytes")
            img_byte_arr = io.BytesIO()
            result.save(img_byte_arr, format="PNG")
            debug_log("Returning image response")
            # getvalue() returns the whole buffer regardless of stream position.
            return Response(content=img_byte_arr.getvalue(), media_type="image/png")
        debug_log(f"Returning JSON response: {result}")
        return result
    except Exception as e:
        # Top-level request boundary: log the full traceback, surface a 500.
        debug_log(f"Error in generate endpoint: {e}")
        import traceback
        debug_log(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e