Spaces:
Runtime error
Runtime error
import os
import secrets
from io import BytesIO
from typing import Optional

import torch
from diffusers import DiffusionPipeline
from fastapi import Depends, FastAPI, Header, Response
from fastapi.middleware.cors import CORSMiddleware
from torch.cuda.amp import autocast
# ASGI application object. CORS is opened to every origin, method and header
# so a browser front-end served from another host can call this API directly.
# NOTE(review): FastAPI's CORS docs say allow_origins cannot be ["*"] when
# credentials are allowed; the secret in this service travels in a request
# header rather than a cookie, so allow_credentials may be unnecessary —
# confirm with the front-end before tightening.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Stable Diffusion XL base checkpoint, loaded once at import time and shared
# by every request handled by this process.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run in half precision on GPU, matching the requested "fp16" weight variant;
# loading those weights as float32 (as before) upcasts the whole model and
# roughly doubles GPU memory. Many CPU kernels lack float16 support, so the
# CPU path stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16",
)
pipe = pipe.to(device)
# The previous `Depends(lambda x: x.headers.get("secret-key"))` never read the
# header: FastAPI inspects the lambda's un-annotated `x` parameter and treats
# it as a required query-string value, so `x.headers` was not a Request.
# Declaring the header with `Header` lets FastAPI extract it directly.
def get_secret(secret_key: Optional[str] = Header(default=None, alias="secret-key")):
    """Return the ``secret-key`` request header, or None when it is absent.

    Intended for use as a FastAPI dependency (``Depends(get_secret)``).
    """
    return secret_key
def generate(prompt: str, secret: str = Depends(get_secret)):
    """Render *prompt* with the SDXL pipeline and return it as a PNG response.

    Args:
        prompt: Text prompt passed to the module-level ``pipe``.
        secret: Caller-supplied secret from the ``get_secret`` dependency;
            may be None when the header is missing.

    Returns:
        A ``Response`` with ``image/png`` bytes on success, or a 401
        ``Response`` with body ``"Unauthorized"`` when authentication fails.

    NOTE(review): no ``@app.get``/``@app.post`` decorator is visible on this
    function, so it is not registered as a route on ``app`` — confirm the
    intended path/method and add the decorator.
    """
    expected = os.getenv("SECRET_KEY")
    # Fail closed. The old `secret != os.getenv(...)` check authorized a
    # request with no header whenever SECRET_KEY was unset (None == None).
    if not expected or secret is None:
        return Response(content="Unauthorized", status_code=401)
    # Constant-time comparison; done on bytes because compare_digest rejects
    # str arguments containing non-ASCII characters.
    if not secrets.compare_digest(secret.encode("utf-8"), expected.encode("utf-8")):
        return Response(content="Unauthorized", status_code=401)

    # Mixed precision only on GPU. torch.cuda.amp.autocast's first positional
    # parameter is `enabled`, not a device, so `autocast(device)` passed a
    # string as the flag; torch.autocast takes the device type explicitly.
    with torch.autocast(device_type=device, enabled=device == "cuda"):
        image = pipe(prompt, guidance_scale=8.5).images[0]

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    # getvalue() returns the whole buffer regardless of position; no seek needed.
    return Response(content=buffer.getvalue(), media_type="image/png")