# Core, Imaging & HTTP Libraries
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import base64
import io
import requests

# API & Server Libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager

# Diffusers & Transformers Libraries
from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
class StagingRequest(BaseModel):
    image_url: str
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    seed: int = 1234
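# Example request body (illustrative values; only image_url and prompt are
# required, negative_prompt and seed fall back to the defaults above):
#   {"image_url": "https://example.com/empty_room.jpg", "prompt": "a cozy rustic bedroom"}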
# Global registry for models loaded once at startup.
models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    print("🚀 Server starting up...")
    device = "cuda"  # the float16 pipeline assumes a CUDA GPU is available
    torch_dtype = torch.float16
    # Semantic segmentation (ADE20K) drives the mask; monocular depth conditions the ControlNet.
    models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
    models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(device)
    models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch_dtype, safety_checker=None
    ).to(device)
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
    print("✅ All models loaded.")
    yield
    print("🛑 Server shutting down.")
    models.clear()

app = FastAPI(lifespan=lifespan)
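# To run locally (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000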
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Segment the room and build an inpainting mask over selected surfaces."""
    processor = models['seg_processor']
    model = models['seg_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Upsample logits to the input resolution; PIL size is (W, H), interpolate expects (H, W).
    upsampled_logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # ADE20K class indices: surfaces to inpaint (inclusion) vs. fixtures to protect (exclusion).
    inclusion_indices = {2, 3, 5}
    exclusion_indices = {14, 17}
    inclusion_mask_np = np.isin(pred_seg, list(inclusion_indices)).astype(np.uint8) * 255
    exclusion_mask_np = np.isin(pred_seg, list(exclusion_indices)).astype(np.uint8) * 255
    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Fill small holes in the mask with a morphological closing.
    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)
def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Estimate a depth map to use as the ControlNet conditioning image."""
    processor = models['depth_processor']
    model = models['depth_model']
    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth
    # Resize the depth prediction back to the input resolution.
    prediction = F.interpolate(predicted_depth.unsqueeze(1), size=image_pil.size[::-1], mode="bicubic", align_corners=False)
    depth_map = prediction.squeeze().cpu().numpy()
    # Normalize to 0-255 and replicate to three channels for the ControlNet input.
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
    depth_map = depth_map.astype(np.uint8)
    return Image.fromarray(np.concatenate([depth_map[..., None]] * 3, axis=-1))
# The route decorator was missing in the original; the path below is assumed from the function name.
@app.post("/furnish_room")
async def furnish_room(request: StagingRequest):
    try:
        response = requests.get(request.image_url, stream=True)
        response.raise_for_status()
        # SD 1.5 expects 512x512 inputs; resize the source image accordingly.
        init_image_pil = Image.open(io.BytesIO(response.content)).convert("RGB").resize((512, 512))
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        generator = torch.Generator(device="cuda").manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]
        # Encode the result as a base64 PNG for the JSON response.
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
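# Minimal client sketch (assumptions: the server runs on localhost:8000 and
# exposes the "/furnish_room" path assumed above):
#
#   import base64, requests
#   payload = {"image_url": "https://example.com/empty_room.jpg",
#              "prompt": "a bright scandinavian living room with a grey sofa"}
#   resp = requests.post("http://localhost:8000/furnish_room", json=payload)
#   resp.raise_for_status()
#   with open("staged_room.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["result_image_base64"]))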