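"""FastAPI service for AI virtual staging: given a photo of a room and a text
prompt, it segments the scene, estimates depth, and inpaints new furnishings
with a depth-conditioned Stable Diffusion ControlNet pipeline."""
# Core Libraries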
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import base64
import io
import requests
# API & Server Libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
# Diffusers & Transformers Libraries
from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
class StagingRequest(BaseModel):
    image_url: str
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    seed: int = 1234
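# Example payload (values are illustrative):
#   {"image_url": "https://example.com/empty-room.jpg",
#    "prompt": "a cozy scandinavian living room with a sofa and coffee table",
#    "seed": 42}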
models = {}  # global registry of loaded models, populated at startup

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models once at startup and release them on shutdown."""
    print("🚀 Server starting up...")
    device = "cuda"  # the fp16 weights below assume a CUDA-capable GPU
    torch_dtype = torch.float16
    # Semantic segmentation (ADE20K) used to build the inpainting mask
    models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
    models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(device)
    # Monocular depth estimation used as the ControlNet conditioning signal
    models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
    # Depth-conditioned ControlNet plugged into an SD 1.5 inpainting pipeline
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch_dtype, safety_checker=None
    ).to(device)
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
    print("✅ All models loaded.")
    yield
    print("⚡ Server shutting down.")
    models.clear()
app = FastAPI(lifespan=lifespan)
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Segment the room and build a binary inpainting mask from selected ADE20K classes."""
    processor = models['seg_processor']
    model = models['seg_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Upsample logits to the input resolution (PIL size is (W, H); interpolate expects (H, W))
    upsampled_logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # ADE20K class indices: paint over the included classes, never over the excluded ones
    inclusion_indices = {2, 3, 5}
    exclusion_indices = {14, 17}
    inclusion_mask_np = np.isin(pred_seg, list(inclusion_indices)).astype(np.uint8) * 255
    exclusion_mask_np = np.isin(pred_seg, list(exclusion_indices)).astype(np.uint8) * 255
    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Morphological closing fills small holes so the mask covers contiguous regions
    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)
def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Estimate depth and return it as a 3-channel image for ControlNet conditioning."""
    processor = models['depth_processor']
    model = models['depth_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth
    # (B, H', W') -> (B, 1, H', W'), then resize to the input resolution
    prediction = F.interpolate(predicted_depth.unsqueeze(1), size=image_pil.size[::-1], mode="bicubic", align_corners=False)
    depth_map = prediction.squeeze().cpu().numpy()
    # Normalize to 0-255 and replicate into three identical channels
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
    depth_map = depth_map.astype(np.uint8)
    return Image.fromarray(np.concatenate([depth_map[..., None]] * 3, axis=-1))
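# Hypothetical local smoke test for the two helpers above (assumes the model
# registry has been populated, e.g. by entering the lifespan context first):
#   img = Image.open("room.jpg").convert("RGB").resize((512, 512))
#   create_precise_mask(img).save("mask.png")
#   generate_depth_map(img).save("depth.png")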
@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
    try:
        # Fetch the source photo and normalize it to SD 1.5's native resolution
        response = requests.get(request.image_url, stream=True, timeout=30)
        response.raise_for_status()
        init_image_pil = Image.open(io.BytesIO(response.content)).convert("RGB").resize((512, 512))
        # Derive the inpainting mask and the depth conditioning image
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        # Seeded generator keeps results reproducible for a given request
        generator = torch.Generator(device="cuda").manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]
        # Encode the result as a base64 PNG for the JSON response
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
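# Example request against a running instance (URL and prompt are illustrative):
#   curl -X POST http://localhost:7860/furnish-room/ \
#     -H "Content-Type: application/json" \
#     -d '{"image_url": "https://example.com/empty-room.jpg", "prompt": "a modern living room"}'

# Optional local entry point, a minimal sketch assuming uvicorn is installed
# (Hugging Face Spaces and similar hosts usually launch the app themselves):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)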