# imageapi / main.py
# Author: Arkm20 — commit 8332ab6 ("Create main.py", verified)
import asyncio
import logging

from fastapi import FastAPI, HTTPException
from gradio_client import Client
from pydantic import BaseModel
# ASGI application object; routes are registered on it with decorators below.
app = FastAPI()

# Hugging Face Space that performs the actual image generation. A single
# client is created once, when this module is imported, and shared by every
# request handled by this process.
SANA_SPRINT_SPACE = "Efficient-Large-Model/SanaSprint"
client = Client(SANA_SPRINT_SPACE)
# Request body schema
class GenerationRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    Each field is forwarded, under the same name, as a keyword argument to
    the SanaSprint Space's ``/infer`` endpoint. No range or choice validation
    is performed here beyond pydantic's type coercion; presumably the Space
    rejects unsupported values itself (TODO confirm).
    """

    prompt: str  # Text prompt; the only required field.
    model_size: str = "1.6B"  # Model variant name; allowed values are defined by the Space (TODO confirm).
    seed: int = 0  # Seed sent to the Space; presumably ignored when randomize_seed is true — verify.
    randomize_seed: bool = True  # Asks the Space to pick its own seed (semantics owned by the Space).
    width: int = 1024  # Output width; presumably pixels — confirm against the Space.
    height: int = 1024  # Output height; presumably pixels — confirm against the Space.
    guidance_scale: float = 4.5  # Forwarded as-is to the Space.
    num_inference_steps: int = 2  # Forwarded as-is to the Space.
@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Generate an image by calling the SanaSprint Space's ``/infer`` endpoint.

    Every field of *request* is forwarded unchanged as a keyword argument
    of the same name.

    Args:
        request: Validated JSON request body.

    Returns:
        ``{"result": <value returned by client.predict>}``. The shape of that
        value is defined by the upstream Space (presumably the generated
        image reference plus the seed used — TODO confirm).

    Raises:
        HTTPException: status 500, with the upstream error message as
            ``detail``, if the call to the Space fails for any reason.
    """
    try:
        # ``client.predict`` is synchronous and blocks until the Space
        # answers. Calling it directly inside this coroutine would stall the
        # event loop and serialize every request. Running it in a worker
        # thread lets the server keep handling other requests meanwhile.
        result = await asyncio.to_thread(
            client.predict,
            prompt=request.prompt,
            model_size=request.model_size,
            seed=request.seed,
            randomize_seed=request.randomize_seed,
            width=request.width,
            height=request.height,
            guidance_scale=request.guidance_scale,
            num_inference_steps=request.num_inference_steps,
            api_name="/infer",
        )
    except Exception as e:
        # Request boundary: keep the full traceback in the server log and
        # chain the cause. Callers keep receiving the same 500 + message.
        logging.getLogger(__name__).exception("Image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result}