wzhouxiff committed on
Commit
979cf8b
1 Parent(s): c24e97b
Files changed (2) hide show
  1. app.py +50 -2
  2. objctrl_2_5d/utils/ui_utils.py +0 -43
app.py CHANGED
@@ -6,8 +6,10 @@ import torch
6
  from gradio_image_prompter import ImagePrompter
7
  from sam2.sam2_image_predictor import SAM2ImagePredictor
8
  from omegaconf import OmegaConf
 
 
9
 
10
- from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, run_segment, run_depth, get_points, undo_points
11
 
12
 
13
  from cameractrl.inference import get_pipeline
@@ -114,6 +116,51 @@ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], mode
114
  # d_model_NK = None
115
  # pipeline = None
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # -------------- UI definition --------------
119
  with gr.Blocks() as demo:
@@ -270,7 +317,8 @@ with gr.Blocks() as demo:
270
  )
271
 
272
  select_button.click(
273
- run_segment(segmentor),
 
274
  [canvas, original_image, mask_logits],
275
  [mask, mask_output, masked_original_image, mask_logits]
276
  )
 
6
  from gradio_image_prompter import ImagePrompter
7
  from sam2.sam2_image_predictor import SAM2ImagePredictor
8
  from omegaconf import OmegaConf
9
+ from PIL import Image
10
+ import numpy as np
11
 
12
+ from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, run_depth, get_points, undo_points, mask_image
13
 
14
 
15
  from cameractrl.inference import get_pipeline
 
116
  # d_model_NK = None
117
  # pipeline = None
118
 
119
+ ### run the demo ##
120
@spaces.GPU(duration=50)
def segment(canvas, image, logits):
    """Run the module-level ``segmentor`` on ``image`` using the prompts in ``canvas``.

    Args:
        canvas: Value handed to ``get_subject_points``; the second item it
            returns is iterated as prompt entries, each unpacked as
            ``[x1, y1, _, x2, y2, _]``. An entry with ``x2 == 0 and y2 == 0``
            is used as a point prompt at ``(x1, y1)``; any other entry is
            used as a box prompt ``[x1, y1, x2, y2]``.
        image: Image to segment; converted with ``np.array`` before use.
        logits: Logits returned by an earlier call (this function returns
            them divided by 32), or ``None`` when there is no earlier result.

    Returns:
        ``(mask, masked_img, masked_img, logits)`` where ``mask`` is the first
        predicted mask thresholded at zero, ``masked_img`` is the PIL image
        built from the ``mask_image`` overlay (returned twice), and ``logits``
        is the predictor's logits output divided by 32.
    """
    if logits is not None:
        # Undo the /32 applied on return. Build a new array rather than using
        # ``*=`` so the object the caller passed in is never modified.
        logits = logits * 32.0
    _, points = get_subject_points(canvas)
    image = np.array(image)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        segmentor.set_image(image)

        # Split the prompt entries into point prompts and box prompts.
        point_prompts = []
        box_prompts = []
        for x1, y1, _, x2, y2, _ in points:
            if x2 == 0 and y2 == 0:
                point_prompts.append([x1, y1])
            else:
                box_prompts.append([x1, y1, x2, y2])

        # The predictor receives ``None`` for any prompt kind that is absent.
        point_coords = np.array(point_prompts) if point_prompts else None
        # Every point prompt gets label 1.
        point_labels = np.ones(len(point_prompts)) if point_prompts else None
        boxes = np.array(box_prompts) if box_prompts else None

        masks, _, logits = segmentor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=boxes,
            multimask_output=False,
            return_logits=True,
            mask_input=logits,
        )
        # ``return_logits=True`` was requested, so threshold at zero here.
        mask = masks > 0
        masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
        masked_img = Image.fromarray(masked_img)

    return mask[0], masked_img, masked_img, logits / 32.0
161
+
162
+ # return segment
163
+
164
 
165
  # -------------- UI definition --------------
166
  with gr.Blocks() as demo:
 
317
  )
318
 
319
  select_button.click(
320
+ # run_segment(segmentor),
321
+ segment,
322
  [canvas, original_image, mask_logits],
323
  [mask, mask_output, masked_original_image, mask_logits]
324
  )
objctrl_2_5d/utils/ui_utils.py CHANGED
@@ -52,49 +52,6 @@ def process_image(raw_image):
52
  def get_subject_points(canvas):
53
  return canvas["image"], canvas["points"]
54
 
55
def run_segment(segmentor):
    """Build the segmentation callback bound to ``segmentor``.

    The ``spaces.GPU`` decorator is applied to the returned ``segment``
    function, which is the code that actually runs the predictor. Decorating
    this factory instead would wrap only the one-off factory call and leave
    the returned callback undecorated.

    Args:
        segmentor: Predictor object providing ``set_image`` and ``predict``.

    Returns:
        The ``segment(canvas, image, logits)`` callback.
    """

    @spaces.GPU(duration=50)
    def segment(canvas, image, logits):
        """Segment ``image`` using the prompts found in ``canvas``.

        Prompt entries are unpacked as ``[x1, y1, _, x2, y2, _]``; an entry
        with ``x2 == 0 and y2 == 0`` is a point prompt, any other entry is a
        box prompt. ``logits`` is the value returned by an earlier call
        (divided by 32) or ``None``.

        Returns ``(mask, masked_img, masked_img, logits / 32.0)``.
        """
        if logits is not None:
            # Undo the /32 applied on return. Build a new array rather than
            # using ``*=`` so the caller's object is never modified.
            logits = logits * 32.0
        _, points = get_subject_points(canvas)
        image = np.array(image)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            segmentor.set_image(image)

            # Split the prompt entries into point prompts and box prompts.
            point_prompts = []
            box_prompts = []
            for x1, y1, _, x2, y2, _ in points:
                if x2 == 0 and y2 == 0:
                    point_prompts.append([x1, y1])
                else:
                    box_prompts.append([x1, y1, x2, y2])

            # The predictor receives ``None`` for any absent prompt kind.
            point_coords = np.array(point_prompts) if point_prompts else None
            # Every point prompt gets label 1.
            point_labels = np.ones(len(point_prompts)) if point_prompts else None
            boxes = np.array(box_prompts) if box_prompts else None

            masks, _, logits = segmentor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                box=boxes,
                multimask_output=False,
                return_logits=True,
                mask_input=logits,
            )
            # ``return_logits=True`` was requested, so threshold at zero here.
            mask = masks > 0
            masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
            masked_img = Image.fromarray(masked_img)

        return mask[0], masked_img, masked_img, logits / 32.0

    return segment
98
 
99
  def mask_image(image,
100
  mask,
 
52
  def get_subject_points(canvas):
53
  return canvas["image"], canvas["points"]
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def mask_image(image,
57
  mask,