"""FastAPI service that furnishes a room photo with ControlNet-guided inpainting.

Pipeline for one request:
  1. download the image at ``image_url`` and resize it to 512x512,
  2. build an inpainting mask from a DPT semantic-segmentation model,
  3. build a depth control image from a DPT depth-estimation model,
  4. run a Stable Diffusion ControlNet inpainting pipeline,
  5. return the result as a base64-encoded PNG.
"""
import base64
import io
import logging
from contextlib import asynccontextmanager

import cv2
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image

# API & Server Libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Diffusers & Transformers Libraries
from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler

logger = logging.getLogger(__name__)

# Single source of truth for the device used by model loading and by the
# per-request random generator.
_DEVICE = "cuda"
_TORCH_DTYPE = torch.float16

# Every input image is resized to this (width, height) before processing.
_WORK_SIZE = (512, 512)

# Upper bound (seconds) on connecting to / reading from the image host, so a
# stalled remote server cannot hang a request forever.
_DOWNLOAD_TIMEOUT_S = 30

# Segmentation class ids that are painted into / carved out of the mask.
# NOTE(review): presumably ADE20K class ids (the model is "Intel/dpt-large-ade");
# confirm their meaning against the model's ``config.id2label``.
_INCLUSION_INDICES = (2, 3, 5)
_EXCLUSION_INDICES = (14, 17)


class StagingRequest(BaseModel):
    """Request body for ``POST /furnish-room/``."""

    # URL the source room image is downloaded from.
    image_url: str
    # Text prompt passed to the inpainting pipeline.
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    # Seed for the torch generator, so results are reproducible per request.
    seed: int = 1234


# Registry of loaded processors/models, filled by ``lifespan`` at startup.
models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once at startup and drop the references at shutdown."""
    print("🚀 Server starting up...")

    models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
    models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(_DEVICE)

    models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(_DEVICE)

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=_TORCH_DTYPE)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=_TORCH_DTYPE,
        safety_checker=None,
    ).to(_DEVICE)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    models['inpainting_pipe'] = pipe

    print("✅ All models loaded.")
    try:
        yield
    finally:
        # Runs even when shutdown is caused by an exception propagating
        # through the ``yield``.
        print("⚡ Server shutting down.")
        models.clear()


app = FastAPI(lifespan=lifespan)


def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Return a single-channel 0/255 inpainting mask for ``image_pil``.

    Pixels whose predicted segmentation class is in ``_INCLUSION_INDICES`` are
    set to 255, pixels in ``_EXCLUSION_INDICES`` are forced to 0, and small
    holes are then filled with a 10x10 morphological close. The mask has the
    same width and height as the input image.
    """
    processor = models['seg_processor']
    model = models['seg_model']

    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # PIL reports (width, height); interpolate wants (height, width).
    upsampled_logits = F.interpolate(
        logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    raw_mask_np = np.isin(pred_seg, _INCLUSION_INDICES).astype(np.uint8) * 255
    # Each pixel carries exactly one class id, so with the current disjoint
    # index sets this is a no-op; it only matters if the sets ever overlap.
    raw_mask_np[np.isin(pred_seg, _EXCLUSION_INDICES)] = 0

    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)


def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Return a 3-channel uint8 depth image the same size as ``image_pil``.

    The predicted depth is min-max normalised to 0..255 and copied into three
    identical channels. A constant prediction (max == min) yields an all-zero
    image instead of dividing by zero.
    """
    processor = models['depth_processor']
    model = models['depth_model']

    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth

    prediction = F.interpolate(
        predicted_depth.unsqueeze(1),
        size=image_pil.size[::-1],  # (height, width)
        mode="bicubic",
        align_corners=False,
    )
    # Index batch 0 / channel 0 explicitly rather than squeeze(), which would
    # also drop a spatial axis of length 1.
    depth_map = prediction[0, 0].cpu().numpy()

    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range * 255.0
    else:
        # Flat depth: normalising would be 0/0 -> NaN -> undefined uint8 cast.
        depth_map = np.zeros_like(depth_map)
    depth_map = depth_map.astype(np.uint8)

    return Image.fromarray(np.repeat(depth_map[..., None], 3, axis=-1))


def _download_image(url: str) -> Image.Image:
    """Fetch ``url`` and return it as an RGB image resized to ``_WORK_SIZE``.

    Raises:
        requests.RequestException: network failure, timeout or non-2xx status.
        OSError: the payload cannot be decoded as an image by PIL.
    """
    # SECURITY NOTE(review): ``url`` is caller-supplied and fetched from inside
    # the server's network with no host allow-list or size cap (SSRF / memory
    # exhaustion risk). Consider restricting hosts and limiting the body size.
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB").resize(_WORK_SIZE)


@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
    """Inpaint furniture into the requested room image.

    Returns:
        ``{"result_image_base64": <base64-encoded PNG>}``.

    Raises:
        HTTPException(400): the image could not be downloaded or decoded.
        HTTPException(500): any failure during mask/depth/inpainting inference.
    """
    # NOTE(review): the download and the GPU inference below are blocking calls
    # made directly inside an ``async`` handler, so the event loop is held for
    # the whole request. Left as-is on purpose: moving the work to a thread
    # pool would let several requests use the GPU pipeline concurrently.
    try:
        init_image_pil = _download_image(request.image_url)
    except (requests.RequestException, OSError) as e:
        # The caller supplied an unusable URL/image: report it as a client
        # error rather than a server fault.
        raise HTTPException(
            status_code=400,
            detail=f"Could not load image from image_url: {e}",
        ) from e

    try:
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)

        generator = torch.Generator(device=_DEVICE).manual_seed(request.seed)

        final_image = models['inpainting_pipe'](
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            image=init_image_pil,
            mask_image=mask_image_pil,
            control_image=control_image_pil,
            num_inference_steps=30,
            guidance_scale=8.0,
            generator=generator,
        ).images[0]

        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        # Top-level boundary: record the traceback server-side, then surface
        # the failure to the client as a 500.
        logger.exception("furnish_room failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"result_image_base64": img_str}