# Imaging & ML Libraries
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import base64
import io
import requests

# API & Server Libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager

# Diffusers & Transformers Libraries
from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler

class StagingRequest(BaseModel):
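    """Request body for /furnish-room/: source image URL plus staging prompts."""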
    image_url: str
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    seed: int = 1234

models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
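    """Load all models once at startup and release them on shutdown.

    Assumes a CUDA GPU: the fp16 weights and hard-coded "cuda" device make
    CPU execution impractical for these pipelines.
    """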
    print("🚀 Server starting up...")
    device = "cuda"
    torch_dtype = torch.float16
    models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
    models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(device)
    models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch_dtype, safety_checker=None
    ).to(device)
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
    print("✅ All models loaded.")
    yield
    print("⚡ Server shutting down.")
    models.clear()

app = FastAPI(lifespan=lifespan)

def create_precise_mask(image_pil: Image.Image) -> Image.Image:
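    """Build the inpainting mask from ADE20K semantic segmentation.

    Pixels whose predicted class falls in the inclusion set are marked for
    inpainting (255); pixels in the exclusion set are forced back to 0 so
    those elements are preserved. A morphological close then fills small
    holes so the mask is contiguous.
    """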
    processor = models['seg_processor']
    model = models['seg_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    upsampled_logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # ADE20K class indices: the inclusion set marks regions eligible for
    # inpainting; the exclusion set protects elements that must stay untouched.
    inclusion_indices = {2, 3, 5}
    exclusion_indices = {14, 17}
    inclusion_mask_np = np.isin(pred_seg, list(inclusion_indices)).astype(np.uint8) * 255
    exclusion_mask_np = np.isin(pred_seg, list(exclusion_indices)).astype(np.uint8) * 255
    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Close small holes so the inpainting mask is contiguous.
    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)

def generate_depth_map(image_pil: Image.Image) -> Image.Image:
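    """Produce a 3-channel depth map (MiDaS) to condition the ControlNet.

    The raw depth prediction is upsampled to the input resolution, min-max
    normalized to 0-255, and replicated across three channels.
    """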
    processor = models['depth_processor']
    model = models['depth_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth
    prediction = F.interpolate(predicted_depth.unsqueeze(1), size=image_pil.size[::-1], mode="bicubic", align_corners=False)
    depth_map = prediction.squeeze().cpu().numpy()
    # Min-max normalize to 0-255 (guarding against a constant map) and
    # replicate to three channels, since the conditioning image is RGB.
    depth_map = (depth_map - depth_map.min()) / max(depth_map.max() - depth_map.min(), 1e-8) * 255.0
    depth_map = depth_map.astype(np.uint8)
    return Image.fromarray(np.stack([depth_map] * 3, axis=-1))

@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
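    """Download the input photo, build the mask and depth conditioning, and
    run ControlNet-guided inpainting. Returns the staged room as base64 PNG.
    """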
    try:
        response = requests.get(request.image_url, timeout=30)
        response.raise_for_status()
        # SD 1.5 works best at its native 512x512 resolution.
        init_image_pil = Image.open(io.BytesIO(response.content)).convert("RGB").resize((512, 512))
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        generator = torch.Generator(device="cuda").manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
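
# Local entry point: a minimal sketch, assuming uvicorn is installed and a
# CUDA GPU is available. Example request (hypothetical URL and prompt):
#   curl -X POST http://localhost:8000/furnish-room/ \
#     -H "Content-Type: application/json" \
#     -d '{"image_url": "https://example.com/empty-room.jpg", "prompt": "modern scandinavian living room"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)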