jhj0517 committed on
Commit
e56e825
1 Parent(s): bbbee26

Add video propagation

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. modules/sam_inference.py +121 -47
app.py CHANGED
@@ -191,6 +191,9 @@ class App:
191
  btn_generate_preview.click(fn=self.sam_inf.add_filter_to_preview,
192
  inputs=preview_params,
193
  outputs=[img_preview])
 
 
 
194
 
195
  self.demo.queue().launch(inbrowser=True)
196
 
 
191
  btn_generate_preview.click(fn=self.sam_inf.add_filter_to_preview,
192
  inputs=preview_params,
193
  outputs=[img_preview])
194
+ btn_generate.click(fn=self.sam_inf.add_filter_to_video,
195
+ inputs=preview_params,
196
+ outputs=None)
197
 
198
  self.demo.queue().launch(inbrowser=True)
199
 
modules/sam_inference.py CHANGED
@@ -14,7 +14,7 @@ from modules.model_downloader import (
14
  is_sam_exist,
15
  download_sam_model_url
16
  )
17
- from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
18
  from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
19
  from modules.mask_utils import (
20
  save_psd_with_masks,
@@ -23,6 +23,8 @@ from modules.mask_utils import (
23
  create_mask_pixelized_image,
24
  create_solid_color_mask_image
25
  )
 
 
26
  from modules.logger_util import get_logger
27
 
28
  MODEL_CONFIGS = {
@@ -45,7 +47,8 @@ class SamInference:
45
  self.model_dir = model_dir
46
  self.output_dir = output_dir
47
  self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
48
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
49
  self.mask_generator = None
50
  self.image_predictor = None
51
  self.video_predictor = None
@@ -89,8 +92,10 @@ class SamInference:
89
  raise f"Error while loading SAM2 model!: {e}"
90
 
91
  def init_video_inference_state(self,
92
- model_type: str,
93
- vid_input: str):
 
 
94
 
95
  if self.video_predictor is None or model_type != self.current_model_type:
96
  self.current_model_type = model_type
@@ -141,21 +146,25 @@ class SamInference:
141
  multimask_output=params["multimask_output"],
142
  )
143
  except Exception as e:
144
- logger.exception("Error while predicting image with prompt")
145
- raise f"Error while predicting image with prompt: {str(e)}"
146
  return masks, scores, logits
147
 
148
- def predict_frame(self,
149
- frame_idx: int,
150
- obj_id: int,
151
- inference_state: Dict,
152
- points: Optional[np.ndarray] = None,
153
- labels: Optional[np.ndarray] = None,
154
- box: Optional[np.ndarray] = None):
155
- if self.video_predictor is None or self.video_inference_state is None:
 
156
  logger.exception("Error while predicting frame from video, load video predictor first")
157
  raise f"Error while predicting frame from video"
158
 
 
 
 
159
  try:
160
  out_frame_idx, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
161
  inference_state=inference_state,
@@ -166,15 +175,43 @@ class SamInference:
166
  box=box
167
  )
168
  except Exception as e:
169
- logger.exception("Error while predicting frame with prompt")
170
- print(e)
171
- raise f"Error while predicting frame with prompt"
172
 
173
  return out_frame_idx, out_obj_ids, out_mask_logits
174
 
175
- def predict_video(self,
176
- video_input):
177
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  def add_filter_to_preview(self,
180
  image_prompt_input_data: Dict,
@@ -183,49 +220,86 @@ class SamInference:
183
  pixel_size: Optional[int] = None,
184
  color_hex: Optional[str] = None,
185
  ):
186
- if not image_prompt_input_data["points"]:
187
- error_message = ("Prompt data is empty! Please provide at least one point or box on the image. <br>"
188
- "If you've already added prompts, please press the eraser button "
189
- "and add your prompts again.")
190
- logger.error(error_message)
191
- raise gr.Error(error_message, duration=20)
192
-
193
  if self.video_predictor is None or self.video_inference_state is None:
194
  logger.exception("Error while adding filter to preview, load video predictor first")
195
  raise f"Error while adding filter to preview"
196
 
 
 
 
 
 
 
197
  image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
198
  image = np.array(image.convert("RGB"))
199
 
200
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  if filter_mode == COLOR_FILTER:
203
- idx, scores, logits = self.predict_frame(
204
- frame_idx=frame_idx,
205
- obj_id=0,
206
- inference_state=self.video_inference_state,
207
- points=point_coords,
208
- labels=point_labels,
209
- box=box
210
- )
211
- masks = (logits[0] > 0.0).cpu().numpy()
212
- generated_masks = self.format_to_auto_result(masks)
213
  image = create_solid_color_mask_image(image, generated_masks, color_hex)
214
 
215
  elif filter_mode == PIXELIZE_FILTER:
216
- idx, scores, logits = self.predict_frame(
217
- frame_idx=frame_idx,
218
- obj_id=0,
219
- inference_state=self.video_inference_state,
220
- points=point_coords,
221
- labels=point_labels,
222
- box=box
223
- )
224
- masks = (logits[0] > 0.0).cpu().numpy()
225
- generated_masks = self.format_to_auto_result(masks)
226
  image = create_mask_pixelized_image(image, generated_masks, pixel_size)
 
227
  return image
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def divide_layer(self,
230
  image_input: np.ndarray,
231
  image_prompt_input_data: Dict,
 
14
  is_sam_exist,
15
  download_sam_model_url
16
  )
17
+ from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
18
  from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
19
  from modules.mask_utils import (
20
  save_psd_with_masks,
 
23
  create_mask_pixelized_image,
24
  create_solid_color_mask_image
25
  )
26
+ from modules.video_utils import get_frames_from_dir
27
+ from modules.utils import save_image
28
  from modules.logger_util import get_logger
29
 
30
  MODEL_CONFIGS = {
 
47
  self.model_dir = model_dir
48
  self.output_dir = output_dir
49
  self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
50
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
52
  self.mask_generator = None
53
  self.image_predictor = None
54
  self.video_predictor = None
 
92
  raise f"Error while loading SAM2 model!: {e}"
93
 
94
  def init_video_inference_state(self,
95
+ vid_input: str,
96
+ model_type: Optional[str] = None):
97
+ if model_type is None:
98
+ model_type = self.current_model_type
99
 
100
  if self.video_predictor is None or model_type != self.current_model_type:
101
  self.current_model_type = model_type
 
146
  multimask_output=params["multimask_output"],
147
  )
148
  except Exception as e:
149
+ logger.exception(f"Error while predicting image with prompt: {str(e)}")
150
+ raise RuntimeError(f"Error while predicting image with prompt: {str(e)}") from e
151
  return masks, scores, logits
152
 
153
+ def add_prediction_to_frame(self,
154
+ frame_idx: int,
155
+ obj_id: int,
156
+ inference_state: Optional[Dict] = None,
157
+ points: Optional[np.ndarray] = None,
158
+ labels: Optional[np.ndarray] = None,
159
+ box: Optional[np.ndarray] = None):
160
+ if (self.video_predictor is None or
161
+ inference_state is None and self.video_inference_state is None):
162
  logger.exception("Error while predicting frame from video, load video predictor first")
163
  raise f"Error while predicting frame from video"
164
 
165
+ if inference_state is None:
166
+ inference_state = self.video_inference_state
167
+
168
  try:
169
  out_frame_idx, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
170
  inference_state=inference_state,
 
175
  box=box
176
  )
177
  except Exception as e:
178
+ logger.exception(f"Error while predicting frame with prompt: {str(e)}")
179
+ raise RuntimeError(f"Failed to predicting frame with prompt: {str(e)}") from e
 
180
 
181
  return out_frame_idx, out_obj_ids, out_mask_logits
182
 
183
def propagate_in_video(self,
                       inference_state: Optional[Dict] = None,):
    """Propagate the prompts stored in an inference state across the video.

    Args:
        inference_state: Video inference state to propagate. When ``None``,
            ``self.video_inference_state`` is used instead.

    Returns:
        A dict keyed by frame index. Each value is a dict with:
            "image": the frame taken from ``get_frames_from_dir(TEMP_DIR)``
                at that index.
            "mask": numpy array obtained by thresholding the first entry of
                the predictor's mask logits at ``> 0.0``.

    Raises:
        RuntimeError: If no video predictor / inference state is available,
            or if propagation itself fails (the original error is chained).
    """
    if inference_state is None:
        inference_state = self.video_inference_state

    if self.video_predictor is None or inference_state is None:
        # Not inside an ``except`` block, so there is no traceback to attach:
        # log with ``error`` and raise a real exception (raising a plain str
        # would itself fail with TypeError).
        error_message = "Error while propagating in video, load video predictor first"
        logger.error(error_message)
        raise RuntimeError(error_message)

    video_segments = {}

    try:
        generator = self.video_predictor.propagate_in_video(
            inference_state=inference_state,
            start_frame_idx=0
        )
        images = get_frames_from_dir(vid_dir=TEMP_DIR, as_numpy=True)

        # NOTE(review): dtype is hard-coded to float16 here while the class
        # also keeps ``self.dtype`` — confirm which one is intended on CPU.
        with torch.autocast(device_type=self.device, dtype=torch.float16):
            for out_frame_idx, out_obj_ids, out_mask_logits in generator:
                # Only the first entry of the logits is turned into a mask.
                mask = (out_mask_logits[0] > 0.0).cpu().numpy()
                video_segments[out_frame_idx] = {
                    "image": images[out_frame_idx],
                    "mask": mask
                }
                logger.debug("Propagated frame_idx: %s", out_frame_idx)
    except Exception as e:
        logger.exception(f"Error while propagating in video: {str(e)}")
        raise RuntimeError(f"Failed to propagate in video: {str(e)}") from e

    return video_segments
215
 
216
  def add_filter_to_preview(self,
217
  image_prompt_input_data: Dict,
 
220
  pixel_size: Optional[int] = None,
221
  color_hex: Optional[str] = None,
222
  ):
 
 
 
 
 
 
 
223
  if self.video_predictor is None or self.video_inference_state is None:
224
  logger.exception("Error while adding filter to preview, load video predictor first")
225
  raise f"Error while adding filter to preview"
226
 
227
+ if not image_prompt_input_data["points"]:
228
+ error_message = ("No prompt data provided. If this is an incorrect flag, "
229
+ "Please press the eraser button (on the image prompter) and add your prompts again.")
230
+ logger.error(error_message)
231
+ raise gr.Error(error_message, duration=20)
232
+
233
  image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
234
  image = np.array(image.convert("RGB"))
235
 
236
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
237
+ obj_id = frame_idx
238
+
239
+ self.video_predictor.reset_state(self.video_inference_state)
240
+ idx, scores, logits = self.add_prediction_to_frame(
241
+ frame_idx=frame_idx,
242
+ obj_id=obj_id,
243
+ inference_state=self.video_inference_state,
244
+ points=point_coords,
245
+ labels=point_labels,
246
+ box=box
247
+ )
248
+ masks = (logits[0] > 0.0).cpu().numpy()
249
+ generated_masks = self.format_to_auto_result(masks)
250
 
251
  if filter_mode == COLOR_FILTER:
 
 
 
 
 
 
 
 
 
 
252
  image = create_solid_color_mask_image(image, generated_masks, color_hex)
253
 
254
  elif filter_mode == PIXELIZE_FILTER:
 
 
 
 
 
 
 
 
 
 
255
  image = create_mask_pixelized_image(image, generated_masks, pixel_size)
256
+
257
  return image
258
 
259
def add_filter_to_video(self,
                        image_prompt_input_data: Dict,
                        filter_mode: str,
                        frame_idx: int,
                        pixel_size: Optional[int] = None,
                        color_hex: Optional[str] = None,):
    """Apply a mask filter to every propagated frame and save the results.

    The prompt drawn on ``frame_idx`` is registered with the video predictor,
    propagated through the video, and each resulting frame is filtered and
    written to ``TEMP_OUT_DIR`` as ``<frame index, 5 digits>.jpg``.

    Args:
        image_prompt_input_data: Dict holding the prompt under the key
            "points".
        filter_mode: Either ``COLOR_FILTER`` or ``PIXELIZE_FILTER``.
        frame_idx: Index of the frame the prompt belongs to. It is also used
            as the object id of the prediction.
        pixel_size: Passed to ``create_mask_pixelized_image`` for
            ``PIXELIZE_FILTER``.
        color_hex: Passed to ``create_solid_color_mask_image`` for
            ``COLOR_FILTER``.

    Raises:
        RuntimeError: If the video predictor or inference state is not loaded.
        gr.Error: If no prompt data was provided.
        ValueError: If ``filter_mode`` is not a supported filter.
    """
    if self.video_predictor is None or self.video_inference_state is None:
        # Not inside an ``except`` block, so log with ``error`` and raise a
        # real exception (raising a plain str would itself fail).
        error_message = "Error while adding filter to video, load video predictor first"
        logger.error(error_message)
        raise RuntimeError(error_message)

    if not image_prompt_input_data["points"]:
        error_message = ("No prompt data provided. If this is an incorrect flag, "
                         "Please press the eraser button (on the image prompter) and add your prompts again.")
        logger.error(error_message)
        raise gr.Error(error_message, duration=20)

    # Reject unsupported modes up front instead of failing on an unbound
    # variable after the whole video has already been propagated.
    if filter_mode not in (COLOR_FILTER, PIXELIZE_FILTER):
        error_message = f"Unsupported filter mode: {filter_mode}"
        logger.error(error_message)
        raise ValueError(error_message)

    prompt = image_prompt_input_data["points"]

    point_labels, point_coords, box = self.handle_prompt_data(prompt)
    obj_id = frame_idx

    # Start from a clean state so earlier prompts do not leak into this run.
    self.video_predictor.reset_state(self.video_inference_state)
    self.add_prediction_to_frame(
        frame_idx=frame_idx,
        obj_id=obj_id,
        inference_state=self.video_inference_state,
        points=point_coords,
        labels=point_labels,
        box=box
    )

    video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
    for frame_index, info in video_segments.items():
        orig_image, masks = info["image"], info["mask"]
        masks = self.format_to_auto_result(masks)

        if filter_mode == COLOR_FILTER:
            filtered_image = create_solid_color_mask_image(orig_image, masks, color_hex)
        else:
            filtered_image = create_mask_pixelized_image(orig_image, masks, pixel_size)

        # Format the frame index into the file name; the unformatted
        # "%05d.jpg" made every frame overwrite the same file.
        output_path = os.path.join(TEMP_OUT_DIR, f"{int(frame_index):05d}.jpg")
        save_image(filtered_image, output_path)
302
+
303
  def divide_layer(self,
304
  image_input: np.ndarray,
305
  image_prompt_input_data: Dict,