|
|
import asyncio

from fastapi import FastAPI, HTTPException
from gradio_client import Client
from pydantic import BaseModel
|
|
|
|
|
|
|
|
# Application object; the route below is registered on it via its decorator.
app = FastAPI()

# Single Gradio client for the "Efficient-Large-Model/SanaSprint" app, built
# once when this module is imported and shared by every request handler.
# NOTE(review): constructing the client presumably contacts the remote app,
# so importing this module likely needs network access -- confirm.
client = Client("Efficient-Large-Model/SanaSprint")
|
|
|
|
|
|
|
|
class GenerationRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Only ``prompt`` is required; every other field has a default. Each
    field is forwarded under the same name to the remote ``/infer`` API.
    """

    # Required text prompt.
    prompt: str
    # Model variant identifier passed to the remote API.
    model_size: str = "1.6B"
    # Seed value; presumably ignored remotely when ``randomize_seed`` is
    # true -- TODO confirm against the remote API.
    seed: int = 0
    randomize_seed: bool = True
    # Requested output dimensions (presumably pixels -- TODO confirm).
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 4.5
    num_inference_steps: int = 2
|
|
|
|
|
@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Forward a generation request to the remote Gradio ``/infer`` API.

    Args:
        request: Validated generation parameters. Every field is passed
            through unchanged as a same-named keyword argument.

    Returns:
        A dict with a single ``"result"`` key holding whatever
        ``client.predict`` returned.

    Raises:
        HTTPException: Status 500 with the stringified error as ``detail``
            when the remote call raises.
    """
    try:
        # ``client.predict`` is synchronous and blocks until the remote job
        # finishes. Running it in a worker thread keeps the event loop free
        # to serve other requests while this one waits.
        result = await asyncio.to_thread(
            client.predict,
            prompt=request.prompt,
            model_size=request.model_size,
            seed=request.seed,
            randomize_seed=request.randomize_seed,
            width=request.width,
            height=request.height,
            guidance_scale=request.guidance_scale,
            num_inference_steps=request.num_inference_steps,
            api_name="/infer",
        )
    except Exception as e:
        # Endpoint boundary: surface any upstream failure as an HTTP 500 and
        # keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"result": result}
|
|
|