ayushpfullstack committed · verified
Commit 545936e · 1 Parent(s): 1a9ca68

Update main.py

Files changed (1)
  1. main.py +13 -61
main.py CHANGED
@@ -12,129 +12,81 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
 
-# Diffusers & Transformers Libraries - UPDATED IMPORTS
+# Diffusers & Transformers Libraries
 from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
 from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
 
-# --- API Data Models ---
 class StagingRequest(BaseModel):
     image_url: str
     prompt: str
     negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
     seed: int = 1234
 
-# --- Global State & Model Loading ---
 models = {}
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # STARTUP: Load all models
-    print("🚀 Server starting up: Loading AI models...")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-    # --- UPDATED: Load processors and models separately ---
-    # Segmentation model
+    print("🚀 Server starting up...")
+    device = "cuda"
+    torch_dtype = torch.float16
     models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
     models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(device)
-
-    # Depth estimation model
     models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
     models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
-
-    # ControlNet and Inpainting Pipeline
     controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
     models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-        "runwayml/stable-diffusion-v1-5",
-        controlnet=controlnet,
-        torch_dtype=torch_dtype,
-        safety_checker=None
+        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch_dtype, safety_checker=None
     ).to(device)
     models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
-
-    print("✅ All models loaded and ready.")
+    print("✅ All models loaded.")
     yield
-    # SHUTDOWN: Clean up
     print("⚡ Server shutting down.")
     models.clear()
 
 app = FastAPI(lifespan=lifespan)
 
-# --- Helper Functions (Core Logic) ---
 def create_precise_mask(image_pil: Image.Image) -> Image.Image:
-    # --- UPDATED: Manual processing and inference ---
-    processor = models['seg_processor']
-    model = models['seg_model']
-
+    processor = models['seg_processor']; model = models['seg_model']
     inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
-    with torch.no_grad():
-        outputs = model(**inputs)
-
+    with torch.no_grad(): outputs = model(**inputs)
     logits = outputs.logits
-    # ADE20k has 150 classes
     upsampled_logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
     pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
-
-    # Use a simplified mapping for room structure labels
-    # Wall=2, Floor=3, Ceiling=5 (based on common ADE20k indices)
-    inclusion_indices = {2, 3, 5}
-    # Door=14, Window=17
-    exclusion_indices = {14, 17}
-
+    inclusion_indices = {2, 3, 5}; exclusion_indices = {14, 17}
     inclusion_mask_np = np.isin(pred_seg, list(inclusion_indices)).astype(np.uint8) * 255
     exclusion_mask_np = np.isin(pred_seg, list(exclusion_indices)).astype(np.uint8) * 255
-
-    raw_mask_np = np.copy(inclusion_mask_np)
-    raw_mask_np[exclusion_mask_np > 0] = 0
+    raw_mask_np = np.copy(inclusion_mask_np); raw_mask_np[exclusion_mask_np > 0] = 0
     mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10,10),np.uint8))
     return Image.fromarray(mask_filled_np)
 
 def generate_depth_map(image_pil: Image.Image) -> Image.Image:
-    # --- UPDATED: Manual processing and inference ---
-    processor = models['depth_processor']
-    model = models['depth_model']
-
+    processor = models['depth_processor']; model = models['depth_model']
     inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
-    with torch.no_grad():
-        outputs = model(**inputs)
-
+    with torch.no_grad(): outputs = model(**inputs)
     predicted_depth = outputs.predicted_depth
     prediction = F.interpolate(predicted_depth.unsqueeze(1), size=image_pil.size[::-1], mode="bicubic", align_corners=False)
-
     depth_map = prediction.squeeze().cpu().numpy()
     depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
     depth_map = depth_map.astype(np.uint8)
     return Image.fromarray(np.concatenate([depth_map[..., None]] * 3, axis=-1))
 
-# --- API Endpoints ---
-@app.get("/")
-def read_root():
-    return {"status": "Virtual Staging API is running."}
-
 @app.post("/furnish-room/")
 async def furnish_room(request: StagingRequest):
     try:
         response = requests.get(request.image_url, stream=True)
         response.raise_for_status()
-        image_bytes = response.content
-        init_image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((512, 512))
-
+        init_image_pil = Image.open(io.BytesIO(response.content)).convert("RGB").resize((512, 512))
         mask_image_pil = create_precise_mask(init_image_pil)
         control_image_pil = generate_depth_map(init_image_pil)
-
         generator = torch.Generator(device="cuda").manual_seed(request.seed)
         final_image = models['inpainting_pipe'](
             prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
             mask_image=mask_image_pil, control_image=control_image_pil,
             num_inference_steps=30, guidance_scale=8.0, generator=generator,
         ).images[0]
-
         buffered = io.BytesIO()
         final_image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
         return {"result_image_base64": img_str}
-    except requests.exceptions.RequestException as e:
-        raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {e}")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
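For reference, a minimal client for the `/furnish-room/` endpoint might look like the sketch below. The host, port, and sample image URL are placeholders, not part of this commit; the request and response shapes follow the `StagingRequest` model and the `result_image_base64` field defined in main.py.

```python
import base64
import requests

# Hypothetical deployment URL; point this at wherever main.py is served.
API_URL = "http://localhost:8000/furnish-room/"

payload = {
    "image_url": "https://example.com/empty-room.jpg",  # placeholder source photo
    "prompt": "a bright modern living room with a grey sofa and a wooden coffee table",
    "seed": 1234,  # matches the StagingRequest default; override for different layouts
}

# Inference runs 30 diffusion steps on the server, so allow a generous timeout.
resp = requests.post(API_URL, json=payload, timeout=600)
resp.raise_for_status()

# The endpoint returns the staged room as a base64-encoded PNG string.
with open("staged_room.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["result_image_base64"]))
```

Note that after this commit the server assumes CUDA is available (`device = "cuda"`, `torch.float16`), so the endpoint will fail on CPU-only hosts.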