prithivMLmods committed (verified)
Commit e933213 · 1 Parent(s): 05bb57b

Update app.py

Files changed (1): app.py (+81 -267)
app.py CHANGED
@@ -1,98 +1,15 @@
 import os
-import sys
-import subprocess
 import spaces
 import gradio as gr
 import numpy as np
 import torch
-import cv2
-import tempfile
-import shutil
-import glob
-from PIL import Image
+import random
+from PIL import Image, ImageDraw
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
+from transformers import Sam3Processor, Sam3Model
 
-# ---------------------------------------------------------
-# 1. ENVIRONMENT SETUP & REPO CLONING
-# ---------------------------------------------------------
-REPO_URL = "https://github.com/facebookresearch/sam-3d-body.git"
-REPO_DIR = "sam-3d-body"
-
-def setup_sam_3d_env():
-    """
-    Clones the repo, installs dependencies, and fixes sys.path
-    so that 'utils', 'tools', and 'sam_3d_body' can be imported.
-    """
-    # 1. Clone if it does not exist
-    if not os.path.exists(REPO_DIR):
-        print(f"Cloning SAM 3D Body repository from {REPO_URL}...")
-        try:
-            subprocess.run(["git", "clone", REPO_URL], check=True)
-            print("Installing sam-3d-body package in editable mode...")
-            # We install using pip to resolve internal package dependencies
-            subprocess.run([sys.executable, "-m", "pip", "install", "-e", REPO_DIR], check=True)
-
-            # Install other requirements that are usually needed
-            subprocess.run([sys.executable, "-m", "pip", "install", "trimesh", "opencv-python", "matplotlib"], check=True)
-        except subprocess.CalledProcessError as e:
-            print(f"Error during setup: {e}")
-            return False
-
-    # 2. Add critical paths to sys.path
-    repo_abs_path = os.path.abspath(REPO_DIR)
-    notebook_path = os.path.join(repo_abs_path, "notebook")
-
-    # CRITICAL: add the repo root first so 'import tools' and 'import sam_3d_body' work inside utils.py
-    if repo_abs_path not in sys.path:
-        sys.path.insert(0, repo_abs_path)
-        print(f"Added to sys.path: {repo_abs_path}")
-
-    # Add the notebook folder so we can 'import utils'
-    if notebook_path not in sys.path:
-        sys.path.insert(0, notebook_path)
-        print(f"Added to sys.path: {notebook_path}")
-
-    return True
-
-# Run setup immediately
-env_ready = setup_sam_3d_env()
-
-# ---------------------------------------------------------
-# 2. IMPORTS
-# ---------------------------------------------------------
-
-# --- Import SAM3 (segmentation) ---
-try:
-    from transformers import Sam3Processor, Sam3Model
-    SAM3_AVAILABLE = True
-except ImportError:
-    print("Warning: transformers library not found or outdated. SAM3 will be disabled.")
-    SAM3_AVAILABLE = False
-
-# --- Import SAM 3D Body utils ---
-# We use a specific alias to avoid confusion with standard Python utils
-sam3d_utils = None
-SAM3D_AVAILABLE = False
-
-if env_ready:
-    try:
-        # Now that sys.path is fixed, this import should work,
-        # and utils.py will successfully find 'tools' and 'sam_3d_body'
-        import utils as sam3d_utils_module
-        sam3d_utils = sam3d_utils_module
-        SAM3D_AVAILABLE = True
-        print("SAM 3D Body utils imported successfully.")
-    except ImportError as e:
-        print(f"Error importing SAM 3D Body utils: {e}")
-        print("This usually happens if 'tools' or 'sam_3d_body' cannot be found by utils.py")
-        import traceback
-        traceback.print_exc()
-
-# ---------------------------------------------------------
-# 3. THEME DEFINITION
-# ---------------------------------------------------------
 colors.steel_blue = colors.Color(
     name="steel_blue",
     c50="#EBF3F8",
@@ -158,222 +75,119 @@ steel_blue_theme = SteelBlueTheme()
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-# ---------------------------------------------------------
-# 4. LOAD MODELS
-# ---------------------------------------------------------
-
-# --- 1. Load SAM3 ---
-sam3_model = None
-sam3_processor = None
-if SAM3_AVAILABLE:
-    try:
-        print("Loading SAM3 Model...")
-        sam3_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
-        sam3_processor = Sam3Processor.from_pretrained("facebook/sam3")
-        print("SAM3 loaded.")
-    except Exception as e:
-        print(f"Error loading SAM3: {e}")
-
-# --- 2. Load SAM 3D Body ---
-sam3d_estimator = None
-sam3d_visualizer = None
-
-if SAM3D_AVAILABLE:
-    try:
-        print("Loading SAM 3D Body Estimator (this may take a moment)...")
-        # Initialize the estimator using the utility function from the repo.
-        # Note: detector_name="vitdet" is the default, which requires the 'tools' import to work.
-        sam3d_estimator = sam3d_utils.setup_sam_3d_body(
-            hf_repo_id="facebook/sam-3d-body-dinov3",
-            device=device
-        )
-        sam3d_visualizer = sam3d_utils.setup_visualizer()
-        print("SAM 3D Body loaded successfully.")
-    except Exception as e:
-        print(f"Error loading SAM 3D Body model: {e}")
-        # If loading fails, clear the flag so the UI handles it gracefully
-        SAM3D_AVAILABLE = False
-        import traceback
-        traceback.print_exc()
-
-# ---------------------------------------------------------
-# 5. INFERENCE FUNCTIONS
-# ---------------------------------------------------------
+try:
+    print("Loading SAM3 Model and Processor...")
+    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
+    processor = Sam3Processor.from_pretrained("facebook/sam3")
+    print("Model loaded successfully.")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    print("Ensure you have the correct libraries installed and access to the model.")
+    # Fallback/placeholder for demonstration if the model is not available in this environment yet
+    model = None
+    processor = None
 
 @spaces.GPU
 def segment_image(input_image, text_prompt, threshold=0.5):
-    """Handler for Tab 1: Segmentation"""
     if input_image is None:
         raise gr.Error("Please upload an image.")
     if not text_prompt:
-        raise gr.Error("Please enter a text prompt.")
-    if sam3_model is None:
-        raise gr.Error("SAM3 Model is not loaded.")
+        raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")
+
+    if model is None or processor is None:
+        raise gr.Error("Model not loaded correctly.")
 
+    # Convert the image to RGB
     image_pil = input_image.convert("RGB")
-    inputs = sam3_processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
 
+    # Preprocess
+    inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
+
+    # Inference
     with torch.no_grad():
-        outputs = sam3_model(**inputs)
+        outputs = model(**inputs)
 
-    results = sam3_processor.post_process_instance_segmentation(
+    # Post-process results
+    results = processor.post_process_instance_segmentation(
         outputs,
         threshold=threshold,
         mask_threshold=0.5,
        target_sizes=inputs.get("original_sizes").tolist()
    )[0]
 
-    masks = results['masks'].cpu().numpy()
-    scores = results['scores'].cpu().numpy()
+    masks = results['masks']    # Boolean tensor [N, H, W]
+    scores = results['scores']
+
+    # Prepare for Gradio AnnotatedImage:
+    # Gradio expects (image, [(mask, label), ...])
 
     annotations = []
-    for i, mask in enumerate(masks):
-        label = f"{text_prompt} ({scores[i]:.2f})"
+    masks_np = masks.cpu().numpy()
+    scores_np = scores.cpu().numpy()
+
+    for i, mask in enumerate(masks_np):
+        # mask is a boolean array (True/False);
+        # AnnotatedImage handles the coloring automatically,
+        # so we just pass the mask and a label.
+        score_val = scores_np[i]
+        label = f"{text_prompt} ({score_val:.2f})"
         annotations.append((mask, label))
 
+    # Return the tuple format for AnnotatedImage
     return (image_pil, annotations)
 
-
-@spaces.GPU
-def process_3d_body(input_image):
-    """Handler for Tab 2: 3D Body Reconstruction"""
-    if input_image is None:
-        raise gr.Error("Please upload an image.")
-
-    if not SAM3D_AVAILABLE or sam3d_estimator is None:
-        raise gr.Error("SAM 3D Body libraries or model failed to load. Check console logs.")
-
-    # Convert PIL to CV2 BGR for the estimator
-    img_np = np.array(input_image.convert("RGB"))
-    img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
-
-    # estimator.process_one_image expects a file path
-    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
-        tmp_path = tmp_file.name
-        cv2.imwrite(tmp_path, img_cv2)
-
-    try:
-        print(f"Processing 3D Body for {tmp_path}...")
-
-        # 1. Run inference.
-        # process_one_image is a method of the estimator class inside sam-3d-body.
-        outputs = sam3d_estimator.process_one_image(tmp_path)
-
-        if not outputs:
-            return None, None, None, "No people detected."
-
-        # 2. 2D keypoints visualization.
-        vis_results_2d = sam3d_utils.visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
-        # vis_results_2d is usually a list of full images with drawings;
-        # ideally we'd grid multiple people, but for the Gradio output we take the first result's image.
-        if vis_results_2d:
-            res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB)
-        else:
-            res_2d_rgb = img_np
-
-        # 3. 3D overlay visualization.
-        # visualize_3d_mesh returns a wide image (Original | Overlay | White | Side).
-        mesh_results_wide = sam3d_utils.visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
-        if mesh_results_wide:
-            res_3d_overlay_rgb = cv2.cvtColor(mesh_results_wide[0], cv2.COLOR_BGR2RGB)
-        else:
-            res_3d_overlay_rgb = img_np
-
-        # 4. Save a PLY for Model3D, in a unique directory for this run
-        output_dir = tempfile.mkdtemp()
-        image_name = "gradio_mesh"
-
-        # save_mesh_results returns a list of paths to .ply files
-        ply_files = sam3d_utils.save_mesh_results(
-            img_cv2,
-            outputs,
-            sam3d_estimator.faces,
-            output_dir,
-            image_name
-        )
-
-        ply_path = None
-        if ply_files and len(ply_files) > 0:
-            ply_path = ply_files[0]  # Return the first mesh found
-
-        status_msg = f"Detected {len(outputs)} person(s). Displaying Person 0."
-
-        return res_2d_rgb, res_3d_overlay_rgb, ply_path, status_msg
-
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        raise gr.Error(f"Inference failed: {str(e)}")
-
-    finally:
-        # Clean up the input temp file
-        if os.path.exists(tmp_path):
-            os.remove(tmp_path)
-
-# ---------------------------------------------------------
-# 6. GUI
-# ---------------------------------------------------------
-
-css = """
+css="""
 #col-container {
     margin: 0 auto;
-    max-width: 1200px;
+    max-width: 980px;
 }
-#main-title h1 {font-size: 2.1em !important; text-align: center;}
-.gradio-container {min-height: 0px !important;}
+#main-title h1 {font-size: 2.1em !important;}
 """
 
 with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
+        gr.Markdown(
+            "# **SAM3 Image Segmentation**",
+            elem_id="main-title"
+        )
 
-        with gr.Tabs():
-            # ================= TAB 1: SEGMENTATION =================
-            with gr.Tab("SAM3 Segmentation"):
-                gr.Markdown("Segment objects using **SAM3** with text prompts.")
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        t1_input = gr.Image(label="Input Image", type="pil", height=350)
-                        t1_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face...")
-                        t1_thresh = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="Threshold")
-                        t1_btn = gr.Button("Segment", variant="primary")
-                    with gr.Column(scale=1.5):
-                        t1_output = gr.AnnotatedImage(label="Segmented Output", height=450)
-
-                t1_btn.click(segment_image, [t1_input, t1_prompt, t1_thresh], [t1_output])
-
-                # Optional examples if files exist
-                # gr.Examples(...)
+        gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(label="Input Image", type="pil", height=300)
+                text_prompt = gr.Textbox(
+                    label="Text Prompt",
+                    placeholder="e.g., cat, ear, car wheel...",
+                )
+
+                run_button = gr.Button("Segment", variant="primary")
 
-            # ================= TAB 2: 3D BODY =================
-            with gr.Tab("SAM 3D Body"):
-                gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
+            with gr.Column(scale=1.5):
+                output_image = gr.AnnotatedImage(label="Segmented Output", height=380)
 
         with gr.Row():
-                    with gr.Column(scale=1):
-                        t2_input = gr.Image(label="Input Image", type="pil", height=350)
-                        t2_btn = gr.Button("Generate 3D Body", variant="primary")
-                        t2_status = gr.Textbox(label="Status", interactive=False)
-
-                    with gr.Column(scale=2):
-                        with gr.Row():
-                            t2_vis_2d = gr.Image(label="2D Detection", type="numpy")
-                            t2_vis_overlay = gr.Image(label="3D Visualization (Original | Overlay | White | Side)", type="numpy")
-
-                        t2_model_3d = gr.Model3D(
-                            label="Interactive 3D Mesh",
-                            clear_color=[0.0, 0.0, 0.0, 0.0],
-                            camera_position=[0, 0, 4.0]
-                        )
+            threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
+
+        gr.Examples(
+            examples=[
+                ["examples/player.jpg", "player in white", 0.5],
+                ["examples/goldencat.webp", "black cat", 0.4],
+                ["examples/taxi.jpg", "blue taxi", 0.5],
+            ],
+            inputs=[input_image, text_prompt, threshold],
+            outputs=[output_image],
+            fn=segment_image,
+            cache_examples="lazy",
+            label="Examples"
+        )
 
-        t2_btn.click(
-            process_3d_body,
-            inputs=[t2_input],
-            outputs=[t2_vis_2d, t2_vis_overlay, t2_model_3d, t2_status]
-        )
+        run_button.click(
+            fn=segment_image,
+            inputs=[input_image, text_prompt, threshold],
+            outputs=[output_image]
+        )
 
 if __name__ == "__main__":
     demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
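For reference, the text-prompted segmentation flow this commit keeps can be exercised outside Gradio. The sketch below reuses the exact Sam3Processor/Sam3Model calls and the facebook/sam3 checkpoint id from the diff; the input filename and the "cat" prompt are placeholders, and a transformers build that ships the Sam3 classes is assumed.

# Minimal standalone sketch of the SAM3 text-prompted segmentation flow
# wired into app.py above. "example.jpg" and "cat" are placeholders.
import torch
from PIL import Image
from transformers import Sam3Processor, Sam3Model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, text="cat", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Same post-processing call as segment_image(): one result dict per input image.
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.4,
    mask_threshold=0.5,
    target_sizes=inputs.get("original_sizes").tolist(),
)[0]

for mask, score in zip(results["masks"], results["scores"]):
    print(f"instance: score={score:.2f}, mask pixels={int(mask.sum())}")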
 
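As the new handler's comments note, gr.AnnotatedImage takes a (base_image, [(mask, label), ...]) tuple and colors the masks itself. Below is a tiny self-contained sketch of that contract using a synthetic boolean mask; all values here are illustrative, not from the commit.

# Hypothetical standalone demo of the AnnotatedImage value format that
# segment_image() returns: (base_image, [(mask, label), ...]).
import numpy as np
import gradio as gr
from PIL import Image

base = Image.new("RGB", (128, 128), "white")
mask = np.zeros((128, 128), dtype=bool)
mask[32:96, 32:96] = True  # synthetic square "instance"

with gr.Blocks() as demo:
    gr.AnnotatedImage(value=(base, [(mask, "square (1.00)")]), label="Annotated")

if __name__ == "__main__":
    demo.launch()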
 
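Note that the new import list pulls in random and ImageDraw even though neither is used in the hunks shown above. If a manual overlay were ever wanted instead of AnnotatedImage's automatic coloring, a helper along these lines could use them; the function below is a hedged sketch, not part of the commit.

# Hypothetical helper: blend one randomly colored, semi-transparent layer per
# boolean (H, W) mask over the image, then stamp a count label with ImageDraw.
import random
import numpy as np
from PIL import Image, ImageDraw

def overlay_masks(image_pil, masks, alpha=0.5, seed=0):
    rng = random.Random(seed)
    out = image_pil.convert("RGBA")
    for mask in masks:
        color = (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255), int(255 * alpha))
        layer = np.zeros((out.height, out.width, 4), dtype=np.uint8)
        layer[np.asarray(mask, dtype=bool)] = color
        out = Image.alpha_composite(out, Image.fromarray(layer))
    draw = ImageDraw.Draw(out)
    draw.text((4, 4), f"{len(masks)} instance(s)", fill=(255, 255, 255, 255))
    return out.convert("RGB")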