import os

# Set cache directories to writable locations *before* importing the HF
# libraries, which read these variables at import time. On Spaces, only
# /tmp is guaranteed to be writable.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets_cache"

import base64
import io
from typing import Optional

import torch
import uvicorn
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Initialize the FastAPI app
app = FastAPI(title="Stable Diffusion API")
| # Define input model | |
| class TextToImageRequest(BaseModel): | |
| prompt: str | |
| negative_prompt: Optional[str] = None | |
| num_inference_steps: Optional[int] = 50 | |
| guidance_scale: Optional[float] = 7.5 | |
| height: Optional[int] = 512 | |
| width: Optional[int] = 512 | |
| seed: Optional[int] = None | |
# Load the model once, when the Space starts
model_id = "CompVis/stable-diffusion-v1-4"

# Use the GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    cache_dir="/tmp/diffusers_cache",
    token=os.environ.get("HF_TOKEN"),  # read the access token from the environment
)
pipe = pipe.to(device)
@app.get("/")
def read_root():
    return {"message": "Stable Diffusion API is running. Use the POST /generate endpoint."}
@app.post("/generate")
async def generate_image(request: TextToImageRequest):
    try:
        # Use a fixed seed for reproducible results when one is provided
        if request.seed is not None:
            generator = torch.Generator(device=device).manual_seed(request.seed)
        else:
            generator = None

        # Run the diffusion pipeline and take the first generated image
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            height=request.height,
            width=request.width,
            generator=generator,
        ).images[0]

        # Encode the PIL image as a base64 PNG string
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return JSONResponse({
            "status": "success",
            "image": img_str,
            "parameters": {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "steps": request.num_inference_steps,
                "guidance_scale": request.guidance_scale,
                "dimensions": f"{request.width}x{request.height}",
                "seed": request.seed,
            },
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)},
        )
# For local testing; not necessary on Spaces
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
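
For reference, here is a minimal client sketch for exercising the endpoint. The base URL and prompt are placeholders, not part of the app above; substitute your Space's URL when calling a deployed instance:

import base64

import requests

# Hypothetical client: the URL below assumes the app is running locally
# on port 7860, as in the __main__ block above.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"prompt": "a watercolor painting of a lighthouse", "seed": 42},
)
resp.raise_for_status()
payload = resp.json()

# Decode the base64 PNG returned by the API and write it to disk
with open("output.png", "wb") as f:
    f.write(base64.b64decode(payload["image"]))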