"""Real-time image-to-image streaming server built on an LCM pipeline.

Clients connect to ``/ws`` and push (image bytes, JSON params) pairs; the
generated frames for a given user are served as a multipart JPEG stream from
``/stream/{user_id}``.
"""

import asyncio
import io
import json
import logging
import os
import time
import traceback
import uuid

from pydantic import BaseModel
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    StreamingResponse,
    JSONResponse,
    HTMLResponse,
    FileResponse,
)

from diffusers import AutoPipelineForImage2Image, AutoencoderTiny
from compel import Compel
import torch

try:
    import intel_extension_for_pytorch as ipex
except Exception:
    # Optional dependency: best-effort import. ``Exception`` (rather than a
    # bare ``except``) keeps KeyboardInterrupt/SystemExit propagating.
    pass
from PIL import Image
import numpy as np
import gradio as gr
import psutil

MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)

WIDTH = 512
HEIGHT = 512
# Set to False to disable the tiny autoencoder (quality vs. speed tradeoff).
USE_TINY_AUTOENCODER = True

# Check if MPS is available (macOS only, M1/M2/M3 chips).
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# Change to torch.float16 to save GPU memory.
torch_dtype = torch.float32

print(f"TIMEOUT: {TIMEOUT}")
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
print(f"device: {device}")

if mps_available:
    # On MPS the pipeline is first moved to CPU, then to the MPS device
    # (see the chained ``.to`` calls below).
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

if SAFETY_CHECKER == "True":
    pipe = AutoPipelineForImage2Image.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
    )
else:
    pipe = AutoPipelineForImage2Image.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        safety_checker=None,
    )

if USE_TINY_AUTOENCODER:
    pipe.vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
    )
pipe.set_progress_bar_config(disable=True)
pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)

# Enable attention slicing on machines with less than 64 GiB of RAM.
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)

    # Run the pipeline once so compilation happens before the first request.
    pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])

compel_proc = Compel(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    truncate_long_prompts=False,
)

# Maps a user id (str) to ``{"queue": asyncio.Queue}`` for each open websocket.
user_queue_map = {}


class InputParams(BaseModel):
    """Generation parameters sent by the client alongside each frame."""

    seed: int = 2159232
    prompt: str
    guidance_scale: float = 8.0
    strength: float = 0.5
    steps: int = 4
    lcm_steps: int = 50
    width: int = WIDTH
    height: int = HEIGHT


def predict(
    input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
):
    """Run the img2img pipeline on one frame.

    Args:
        input_image: Source image passed to the pipeline.
        params: Generation parameters (seed, strength, steps, ...).
        prompt_embeds: Precomputed prompt embeddings passed to the pipeline.

    Returns:
        The first generated PIL image, or ``None`` when the pipeline output
        reports NSFW content for it.
    """
    generator = torch.manual_seed(params.seed)
    results = pipe(
        prompt_embeds=prompt_embeds,
        generator=generator,
        image=input_image,
        strength=params.strength,
        num_inference_steps=params.steps,
        guidance_scale=params.guidance_scale,
        width=params.width,
        height=params.height,
        original_inference_steps=params.lcm_steps,
        output_type="pil",
    )
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        return None
    return results.images[0]


def _drain_queue(queue: asyncio.Queue) -> None:
    """Discard every item currently waiting in *queue*."""
    while not queue.empty():
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            break


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client, register its frame queue and pump incoming frames.

    Rejects the connection with an error message when ``MAX_QUEUE_SIZE`` is
    set and already reached. The user's queue is always removed and drained
    when the connection ends.
    """
    await websocket.accept()
    if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
        print("Server is full")
        await websocket.send_json({"status": "error", "message": "Server is full"})
        await websocket.close()
        return

    # Assigned outside the ``try`` so the handlers below can always refer to it.
    uid = str(uuid.uuid4())
    try:
        print(f"New user connected: {uid}")
        await websocket.send_json(
            {"status": "success", "message": "Connected", "userId": uid}
        )
        user_queue_map[uid] = {"queue": asyncio.Queue()}
        await websocket.send_json(
            {"status": "start", "message": "Start Streaming", "userId": uid}
        )
        await handle_websocket_data(websocket, uid)
    except WebSocketDisconnect as e:
        logging.error(f"WebSocket Error: {e}, {uid}")
        traceback.print_exc()
    finally:
        print(f"User disconnected: {uid}")
        # The entry is missing when the client dropped before registration;
        # guard against ``None`` instead of failing during cleanup.
        queue_value = user_queue_map.pop(uid, None)
        if queue_value is not None:
            queue = queue_value.get("queue")
            if queue is not None:
                _drain_queue(queue)


@app.get("/queue_size")
async def get_queue_size():
    """Return the number of currently registered users."""
    queue_size = len(user_queue_map)
    return JSONResponse({"queue_size": queue_size})


@app.get("/stream/{user_id}")
async def stream(user_id: uuid.UUID):
    """Stream generated frames for *user_id* as multipart JPEG.

    Raises:
        HTTPException: 404 when no queue is registered for the user.
    """
    uid = str(user_id)
    try:
        queue = user_queue_map[uid]["queue"]
    except KeyError as e:
        logging.error(f"Streaming Error: {e}, {user_queue_map}")
        traceback.print_exc()
        # Must be raised (not returned) for FastAPI to emit a 404 response.
        raise HTTPException(status_code=404, detail="User not found") from e

    async def generate():
        last_prompt: str = None
        prompt_embeds: torch.Tensor = None
        while True:
            data = await queue.get()
            input_image = data["image"]
            params = data["params"]
            if input_image is None:
                continue
            # Avoid recalculating prompt embeddings for an unchanged prompt.
            if last_prompt != params.prompt:
                print("new prompt")
                prompt_embeds = compel_proc(params.prompt)
                last_prompt = params.prompt

            # NOTE(review): predict() is a synchronous call made from this
            # coroutine, so it runs on the event loop thread.
            image = predict(
                input_image,
                params,
                prompt_embeds,
            )
            if image is None:
                continue
            frame_buffer = io.BytesIO()
            image.save(frame_buffer, format="JPEG")
            frame_data = frame_buffer.getvalue()
            if frame_data:
                yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"

            await asyncio.sleep(1.0 / 120.0)

    return StreamingResponse(
        generate(), media_type="multipart/x-mixed-replace;boundary=frame"
    )


async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
    """Receive frames from *websocket* and keep only the newest in the queue.

    Each iteration reads one binary message (the image) followed by one JSON
    message (the ``InputParams`` fields). When ``TIMEOUT`` is positive and
    that many seconds have elapsed since this handler started, a timeout
    message is sent and the socket is closed. Errors are logged, not raised.
    """
    uid = str(user_id)
    user_queue = user_queue_map.get(uid)
    if user_queue is None:
        # Not an HTTP route, so an HTTPException would be meaningless here.
        logging.error(f"Error: no queue registered for user {uid}")
        return
    queue = user_queue["queue"]

    last_time = time.time()
    try:
        while True:
            data = await websocket.receive_bytes()
            params = await websocket.receive_json()
            params = InputParams(**params)
            pil_image = Image.open(io.BytesIO(data))

            # Drop stale frames so the consumer only sees the latest one.
            _drain_queue(queue)
            await queue.put({"image": pil_image, "params": params})
            if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
                await websocket.send_json(
                    {
                        "status": "timeout",
                        "message": "Your session has ended",
                        "userId": uid,
                    }
                )
                await websocket.close()
                return

    except Exception as e:
        logging.error(f"Error: {e}")
        traceback.print_exc()


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the static img2img front-end page."""
    return FileResponse("./static/img2img.html")