hysts HF Staff committed
Commit 9487a42
1 Parent(s): 9827da7
Files changed (1)
  1. app.py +167 -175
app.py CHANGED
@@ -129,18 +129,18 @@ class AppState:
 
 
 def init_video_session(
-    GLOBAL_STATE: AppState, video: str | dict, active_tab: str = "point_box"
+    state: AppState, video: str | dict, active_tab: str = "point_box"
 ) -> tuple[AppState, int, int, Image.Image, str]:
-    GLOBAL_STATE.video_frames = []
-    GLOBAL_STATE.masks_by_frame = {}
-    GLOBAL_STATE.color_by_obj = {}
-    GLOBAL_STATE.color_by_prompt = {}
-    GLOBAL_STATE.text_prompts_by_frame_obj = {}
-    GLOBAL_STATE.clicks_by_frame_obj = {}
-    GLOBAL_STATE.boxes_by_frame_obj = {}
-    GLOBAL_STATE.composited_frames = {}
-    GLOBAL_STATE.inference_session = None
-    GLOBAL_STATE.active_tab = active_tab
+    state.video_frames = []
+    state.masks_by_frame = {}
+    state.color_by_obj = {}
+    state.color_by_prompt = {}
+    state.text_prompts_by_frame_obj = {}
+    state.clicks_by_frame_obj = {}
+    state.boxes_by_frame_obj = {}
+    state.composited_frames = {}
+    state.inference_session = None
+    state.active_tab = active_tab
 
     video_path: str | None = None
     if isinstance(video, dict):
@@ -165,14 +165,14 @@ def init_video_session(
         trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
     if isinstance(info, dict):
         info["num_frames"] = len(frames)
-    GLOBAL_STATE.video_frames = frames
-    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
+    state.video_frames = frames
+    state.video_fps = float(fps_in) if fps_in else None
 
     raw_video = [np.array(frame) for frame in frames]
 
     if active_tab == "text":
         processor = TEXT_VIDEO_PROCESSOR
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        state.inference_session = processor.init_video_session(
             video=frames,
             inference_device=DEVICE,
             inference_state_device=DEVICE,
@@ -182,7 +182,7 @@ def init_video_session(
         )
     else:
         processor = TRACKER_PROCESSOR
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        state.inference_session = processor.init_video_session(
            video=raw_video,
             inference_device=DEVICE,
             inference_state_device=DEVICE,
@@ -195,15 +195,15 @@ def init_video_session(
     max_idx = len(frames) - 1
     if active_tab == "text":
         status = (
-            f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
             f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
         )
     else:
         status = (
-            f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
             f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
         )
-    return GLOBAL_STATE, 0, max_idx, first_frame, status
+    return state, 0, max_idx, first_frame, status
 
 
 def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
@@ -596,24 +596,24 @@ def _get_active_prompts_display(state: AppState) -> str:
     return "**Active prompts:** None"
 
 
-def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
-    if GLOBAL_STATE is None:
-        return GLOBAL_STATE, "Load a video first.", gr.update()
+def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
+    if state is None:
+        return state, "Load a video first.", gr.update()
 
-    if GLOBAL_STATE.active_tab != "text" and GLOBAL_STATE.inference_session is None:
-        return GLOBAL_STATE, "Load a video first.", gr.update()
+    if state.active_tab != "text" and state.inference_session is None:
+        return state, "Load a video first.", gr.update()
 
-    total = max(1, GLOBAL_STATE.num_frames)
+    total = max(1, state.num_frames)
     processed = 0
 
-    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
+    yield state, f"Propagating masks: {processed}/{total}", gr.update()
 
     last_frame_idx = 0
 
     with torch.no_grad():
-        if GLOBAL_STATE.active_tab == "text":
-            if GLOBAL_STATE.inference_session is None:
-                yield GLOBAL_STATE, "Text video model not loaded.", gr.update()
+        if state.active_tab == "text":
+            if state.inference_session is None:
+                yield state, "Text video model not loaded.", gr.update()
                 return
 
             model = TEXT_VIDEO_MODEL
@@ -621,7 +621,7 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
 
             # Collect all unique prompts from existing frame annotations
             text_prompt_to_obj_ids = {}
-            for frame_idx, frame_texts in GLOBAL_STATE.text_prompts_by_frame_obj.items():
+            for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items():
                 for obj_id, text_prompt in frame_texts.items():
                     if text_prompt not in text_prompt_to_obj_ids:
                         text_prompt_to_obj_ids[text_prompt] = []
@@ -629,8 +629,8 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                     text_prompt_to_obj_ids[text_prompt].append(obj_id)
 
             # Also check if there are prompts already in the inference session
-            if hasattr(GLOBAL_STATE.inference_session, "prompts") and GLOBAL_STATE.inference_session.prompts:
-                for prompt_text in GLOBAL_STATE.inference_session.prompts.values():
+            if hasattr(state.inference_session, "prompts") and state.inference_session.prompts:
+                for prompt_text in state.inference_session.prompts.values():
                     if prompt_text not in text_prompt_to_obj_ids:
                         text_prompt_to_obj_ids[prompt_text] = []
 
@@ -638,31 +638,29 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                 text_prompt_to_obj_ids[text_prompt].sort()
 
             if not text_prompt_to_obj_ids:
-                yield GLOBAL_STATE, "No text prompts found. Please add a text prompt first.", gr.update()
+                yield state, "No text prompts found. Please add a text prompt first.", gr.update()
                 return
 
            # Add all prompts to the inference session (processor handles deduplication)
            for text_prompt in text_prompt_to_obj_ids:
-                GLOBAL_STATE.inference_session = processor.add_text_prompt(
-                    inference_session=GLOBAL_STATE.inference_session,
+                state.inference_session = processor.add_text_prompt(
+                    inference_session=state.inference_session,
                     text=text_prompt,
                 )
 
-            earliest_frame = (
-                min(GLOBAL_STATE.text_prompts_by_frame_obj.keys()) if GLOBAL_STATE.text_prompts_by_frame_obj else 0
-            )
+            earliest_frame = min(state.text_prompts_by_frame_obj.keys()) if state.text_prompts_by_frame_obj else 0
 
-            frames_to_track = GLOBAL_STATE.num_frames - earliest_frame
+            frames_to_track = state.num_frames - earliest_frame
 
             outputs_per_frame = {}
 
             for model_outputs in model.propagate_in_video_iterator(
-                inference_session=GLOBAL_STATE.inference_session,
+                inference_session=state.inference_session,
                 start_frame_idx=earliest_frame,
                 max_frame_num_to_track=frames_to_track,
             ):
                 processed_outputs = processor.postprocess_outputs(
-                    GLOBAL_STATE.inference_session,
+                    state.inference_session,
                     model_outputs,
                 )
                 frame_idx = model_outputs.frame_idx
@@ -673,8 +671,8 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                 scores = processed_outputs["scores"]
                 prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
 
-                masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
-                frame_texts = GLOBAL_STATE.text_prompts_by_frame_obj.setdefault(frame_idx, {})
+                masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
+                frame_texts = state.text_prompts_by_frame_obj.setdefault(frame_idx, {})
 
                 num_objects = len(object_ids)
                 if num_objects > 0:
@@ -701,137 +699,131 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                         # Store prompt and assign color
                         if found_prompt:
                             frame_texts[current_obj_id] = found_prompt.strip()
-                            _ensure_color_for_obj(GLOBAL_STATE, current_obj_id)
+                            _ensure_color_for_obj(state, current_obj_id)
 
-                GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+                state.composited_frames.pop(frame_idx, None)
                 last_frame_idx = frame_idx
                 processed += 1
                 if processed % 30 == 0 or processed == total:
-                    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+                    yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
         else:
-            if GLOBAL_STATE.inference_session is None:
-                yield GLOBAL_STATE, "Tracker model not loaded.", gr.update()
+            if state.inference_session is None:
+                yield state, "Tracker model not loaded.", gr.update()
                 return
 
             model = TRACKER_MODEL
             processor = TRACKER_PROCESSOR
 
-            for sam2_video_output in model.propagate_in_video_iterator(
-                inference_session=GLOBAL_STATE.inference_session
-            ):
+            for sam2_video_output in model.propagate_in_video_iterator(inference_session=state.inference_session):
                 video_res_masks = processor.post_process_masks(
                     [sam2_video_output.pred_masks],
-                    original_sizes=[
-                        [GLOBAL_STATE.inference_session.video_height, GLOBAL_STATE.inference_session.video_width]
-                    ],
+                    original_sizes=[[state.inference_session.video_height, state.inference_session.video_width]],
                 )[0]
 
                 frame_idx = sam2_video_output.frame_idx
-                for i, out_obj_id in enumerate(GLOBAL_STATE.inference_session.obj_ids):
-                    _ensure_color_for_obj(GLOBAL_STATE, int(out_obj_id))
+                for i, out_obj_id in enumerate(state.inference_session.obj_ids):
+                    _ensure_color_for_obj(state, int(out_obj_id))
                     mask_2d = video_res_masks[i].cpu().numpy()
-                    masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
+                    masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
                     masks_for_frame[int(out_obj_id)] = mask_2d
-                GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+                state.composited_frames.pop(frame_idx, None)
 
                 last_frame_idx = frame_idx
                 processed += 1
                 if processed % 30 == 0 or processed == total:
-                    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+                    yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
     text = f"Propagated masks across {processed} frames."
-    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
+    yield state, text, gr.update(value=last_frame_idx)
 
 
-def reset_prompts(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, str, str]:
+def reset_prompts(state: AppState) -> tuple[AppState, Image.Image, str, str]:
     """Reset prompts and all outputs, but keep processed frames and cached vision features."""
-    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
-        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-        return GLOBAL_STATE, None, "No active session to reset.", active_prompts
+    if state is None or state.inference_session is None:
+        active_prompts = _get_active_prompts_display(state)
+        return state, None, "No active session to reset.", active_prompts
 
-    if GLOBAL_STATE.active_tab != "text":
-        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-        return GLOBAL_STATE, None, "Reset prompts is only available for text prompting mode.", active_prompts
+    if state.active_tab != "text":
+        active_prompts = _get_active_prompts_display(state)
+        return state, None, "Reset prompts is only available for text prompting mode.", active_prompts
 
     # Reset inference session tracking data but keep cache and processed frames
-    if hasattr(GLOBAL_STATE.inference_session, "reset_tracking_data"):
-        GLOBAL_STATE.inference_session.reset_tracking_data()
+    if hasattr(state.inference_session, "reset_tracking_data"):
+        state.inference_session.reset_tracking_data()
 
     # Manually clear prompts (reset_tracking_data doesn't clear prompts themselves)
-    if hasattr(GLOBAL_STATE.inference_session, "prompts"):
-        GLOBAL_STATE.inference_session.prompts.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_input_ids"):
-        GLOBAL_STATE.inference_session.prompt_input_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_embeddings"):
-        GLOBAL_STATE.inference_session.prompt_embeddings.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_attention_masks"):
-        GLOBAL_STATE.inference_session.prompt_attention_masks.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_prompt_id"):
-        GLOBAL_STATE.inference_session.obj_id_to_prompt_id.clear()
+    if hasattr(state.inference_session, "prompts"):
+        state.inference_session.prompts.clear()
+    if hasattr(state.inference_session, "prompt_input_ids"):
+        state.inference_session.prompt_input_ids.clear()
+    if hasattr(state.inference_session, "prompt_embeddings"):
+        state.inference_session.prompt_embeddings.clear()
+    if hasattr(state.inference_session, "prompt_attention_masks"):
+        state.inference_session.prompt_attention_masks.clear()
+    if hasattr(state.inference_session, "obj_id_to_prompt_id"):
+        state.inference_session.obj_id_to_prompt_id.clear()
 
     # Reset detection-tracking fusion state
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_score"):
-        GLOBAL_STATE.inference_session.obj_id_to_score.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_tracker_score_frame_wise"):
-        GLOBAL_STATE.inference_session.obj_id_to_tracker_score_frame_wise.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_last_occluded"):
-        GLOBAL_STATE.inference_session.obj_id_to_last_occluded.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "max_obj_id"):
-        GLOBAL_STATE.inference_session.max_obj_id = -1
-    if hasattr(GLOBAL_STATE.inference_session, "obj_first_frame_idx"):
-        GLOBAL_STATE.inference_session.obj_first_frame_idx.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "unmatched_frame_inds"):
-        GLOBAL_STATE.inference_session.unmatched_frame_inds.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "overlap_pair_to_frame_inds"):
-        GLOBAL_STATE.inference_session.overlap_pair_to_frame_inds.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "trk_keep_alive"):
-        GLOBAL_STATE.inference_session.trk_keep_alive.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "removed_obj_ids"):
-        GLOBAL_STATE.inference_session.removed_obj_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "suppressed_obj_ids"):
-        GLOBAL_STATE.inference_session.suppressed_obj_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "hotstart_removed_obj_ids"):
-        GLOBAL_STATE.inference_session.hotstart_removed_obj_ids.clear()
+    if hasattr(state.inference_session, "obj_id_to_score"):
+        state.inference_session.obj_id_to_score.clear()
+    if hasattr(state.inference_session, "obj_id_to_tracker_score_frame_wise"):
+        state.inference_session.obj_id_to_tracker_score_frame_wise.clear()
+    if hasattr(state.inference_session, "obj_id_to_last_occluded"):
+        state.inference_session.obj_id_to_last_occluded.clear()
+    if hasattr(state.inference_session, "max_obj_id"):
+        state.inference_session.max_obj_id = -1
+    if hasattr(state.inference_session, "obj_first_frame_idx"):
+        state.inference_session.obj_first_frame_idx.clear()
+    if hasattr(state.inference_session, "unmatched_frame_inds"):
+        state.inference_session.unmatched_frame_inds.clear()
+    if hasattr(state.inference_session, "overlap_pair_to_frame_inds"):
+        state.inference_session.overlap_pair_to_frame_inds.clear()
+    if hasattr(state.inference_session, "trk_keep_alive"):
+        state.inference_session.trk_keep_alive.clear()
+    if hasattr(state.inference_session, "removed_obj_ids"):
+        state.inference_session.removed_obj_ids.clear()
+    if hasattr(state.inference_session, "suppressed_obj_ids"):
+        state.inference_session.suppressed_obj_ids.clear()
+    if hasattr(state.inference_session, "hotstart_removed_obj_ids"):
+        state.inference_session.hotstart_removed_obj_ids.clear()
 
     # Clear all app state outputs
-    GLOBAL_STATE.masks_by_frame.clear()
-    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
-    GLOBAL_STATE.composited_frames.clear()
-    GLOBAL_STATE.color_by_obj.clear()
-    GLOBAL_STATE.color_by_prompt.clear()
+    state.masks_by_frame.clear()
+    state.text_prompts_by_frame_obj.clear()
+    state.composited_frames.clear()
+    state.color_by_obj.clear()
+    state.color_by_prompt.clear()
 
     # Update display
-    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
-    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
-    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+    current_idx = int(getattr(state, "current_frame_idx", 0))
+    current_idx = max(0, min(current_idx, state.num_frames - 1))
+    preview_img = update_frame_display(state, current_idx)
+    active_prompts = _get_active_prompts_display(state)
     status = "Prompts and outputs reset. Processed frames and cached vision features preserved."
 
-    return GLOBAL_STATE, preview_img, status, active_prompts
+    return state, preview_img, status, active_prompts
 
 
-def reset_session(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
-    if not GLOBAL_STATE.video_frames:
-        return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
+def reset_session(state: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
+    if not state.video_frames:
+        return state, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
 
-    if GLOBAL_STATE.active_tab == "text":
-        if GLOBAL_STATE.video_frames:
+    if state.active_tab == "text":
+        if state.video_frames:
             processor = TEXT_VIDEO_PROCESSOR
-            GLOBAL_STATE.inference_session = processor.init_video_session(
-                video=GLOBAL_STATE.video_frames,
+            state.inference_session = processor.init_video_session(
+                video=state.video_frames,
                 inference_device=DEVICE,
                 processing_device="cpu",
                 video_storage_device="cpu",
                 dtype=DTYPE,
             )
-        elif GLOBAL_STATE.inference_session is not None and hasattr(
-            GLOBAL_STATE.inference_session, "reset_inference_session"
-        ):
-            GLOBAL_STATE.inference_session.reset_inference_session()
-    elif GLOBAL_STATE.video_frames:
+        elif state.inference_session is not None and hasattr(state.inference_session, "reset_inference_session"):
+            state.inference_session.reset_inference_session()
+    elif state.video_frames:
         processor = TRACKER_PROCESSOR
-        raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        raw_video = [np.array(frame) for frame in state.video_frames]
+        state.inference_session = processor.init_video_session(
            video=raw_video,
             inference_device=DEVICE,
             video_storage_device="cpu",
@@ -839,44 +831,44 @@ def reset_session(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
             dtype=DTYPE,
         )
 
-    GLOBAL_STATE.masks_by_frame.clear()
-    GLOBAL_STATE.clicks_by_frame_obj.clear()
-    GLOBAL_STATE.boxes_by_frame_obj.clear()
-    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
-    GLOBAL_STATE.composited_frames.clear()
-    GLOBAL_STATE.color_by_obj.clear()
-    GLOBAL_STATE.color_by_prompt.clear()
-    GLOBAL_STATE.pending_box_start = None
-    GLOBAL_STATE.pending_box_start_frame_idx = None
-    GLOBAL_STATE.pending_box_start_obj_id = None
+    state.masks_by_frame.clear()
+    state.clicks_by_frame_obj.clear()
+    state.boxes_by_frame_obj.clear()
+    state.text_prompts_by_frame_obj.clear()
+    state.composited_frames.clear()
+    state.color_by_obj.clear()
+    state.color_by_prompt.clear()
+    state.pending_box_start = None
+    state.pending_box_start_frame_idx = None
+    state.pending_box_start_obj_id = None
 
     gc.collect()
 
-    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
-    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
-    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
-    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
+    current_idx = int(getattr(state, "current_frame_idx", 0))
+    current_idx = max(0, min(current_idx, state.num_frames - 1))
+    preview_img = update_frame_display(state, current_idx)
+    slider_minmax = gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True)
     slider_value = gr.update(value=current_idx)
     status = "Session reset. Prompts cleared; video preserved."
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-    return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status, active_prompts
+    active_prompts = _get_active_prompts_display(state)
+    return state, preview_img, slider_minmax, slider_value, status, active_prompts
 
 
-def _on_video_change_pointbox(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str]:
-    GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "point_box")
+def _on_video_change_pointbox(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str]:
+    state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "point_box")
     return (
-        GLOBAL_STATE,
+        state,
         gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
         first_frame,
         status,
     )
 
 
-def _on_video_change_text(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
-    GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "text")
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+def _on_video_change_text(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
+    state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "text")
+    active_prompts = _get_active_prompts_display(state)
     return (
-        GLOBAL_STATE,
+        state,
         gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
         first_frame,
         status,
@@ -885,7 +877,7 @@ def _on_video_change_text(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
 
 
 with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
-    GLOBAL_STATE = gr.State(AppState())
+    app_state = gr.State(AppState())
 
     gr.Markdown(
         """
@@ -953,9 +945,9 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
         with gr.Row():
             gr.Examples(
                 examples=examples_list_text,
-                inputs=[GLOBAL_STATE, video_in_text],
+                inputs=[app_state, video_in_text],
                 fn=_on_video_change_text,
-                outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
+                outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display],
                 label="Examples",
                 cache_examples=False,
                 examples_per_page=5,
@@ -1016,9 +1008,9 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
         with gr.Row():
             gr.Examples(
                 examples=examples_list_pointbox,
-                inputs=[GLOBAL_STATE, video_in_pointbox],
+                inputs=[app_state, video_in_pointbox],
                 fn=_on_video_change_pointbox,
-                outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
+                outputs=[app_state, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
                 label="Examples",
                 cache_examples=False,
                 examples_per_page=5,
@@ -1026,8 +1018,8 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     video_in_pointbox.change(
         _on_video_change_pointbox,
-        inputs=[GLOBAL_STATE, video_in_pointbox],
-        outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
+        inputs=[app_state, video_in_pointbox],
+        outputs=[app_state, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
         show_progress=True,
     )
 
@@ -1038,14 +1030,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     frame_slider_pointbox.change(
         _sync_frame_idx_pointbox,
-        inputs=[GLOBAL_STATE, frame_slider_pointbox],
+        inputs=[app_state, frame_slider_pointbox],
         outputs=preview_pointbox,
     )
 
     video_in_text.change(
         _on_video_change_text,
-        inputs=[GLOBAL_STATE, video_in_text],
-        outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
+        inputs=[app_state, video_in_text],
+        outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display],
         show_progress=True,
     )
 
@@ -1056,7 +1048,7 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     frame_slider_text.change(
         _sync_frame_idx_text,
-        inputs=[GLOBAL_STATE, frame_slider_text],
+        inputs=[app_state, frame_slider_text],
         outputs=preview_text,
    )
 
@@ -1065,14 +1057,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             s.current_obj_id = int(oid)
         return gr.update()
 
-    obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])
+    obj_id_inp.change(_sync_obj_id, inputs=[app_state, obj_id_inp], outputs=[])
 
     def _sync_label(s: AppState, lab: str):
         if s is not None and lab is not None:
             s.current_label = str(lab)
         return gr.update()
 
-    label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])
+    label_radio.change(_sync_label, inputs=[app_state, label_radio], outputs=[])
 
     def _sync_prompt_type(s: AppState, val: str):
         if s is not None and val is not None:
@@ -1087,13 +1079,13 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     prompt_type.change(
        _sync_prompt_type,
-        inputs=[GLOBAL_STATE, prompt_type],
+        inputs=[app_state, prompt_type],
         outputs=[label_radio, clear_old_chk],
    )
 
     preview_pointbox.select(
         on_image_click,
-        [preview_pointbox, GLOBAL_STATE, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
+        [preview_pointbox, app_state, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
         preview_pointbox,
     )
 
@@ -1103,14 +1095,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     text_apply_btn.click(
         _on_text_apply,
-        inputs=[GLOBAL_STATE, frame_slider_text, text_prompt_input],
+        inputs=[app_state, frame_slider_text, text_prompt_input],
         outputs=[preview_text, text_status, active_prompts_display],
     )
 
     reset_prompts_btn.click(
         reset_prompts,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, preview_text, text_status, active_prompts_display],
+        inputs=[app_state],
+        outputs=[app_state, preview_text, text_status, active_prompts_display],
     )
 
     def _render_video(s: AppState):
@@ -1139,32 +1131,32 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             print(f"Failed to render video with cv2: {e}")
             raise gr.Error(f"Failed to render video: {e}")
 
-    render_btn_pointbox.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_pointbox])
-    render_btn_text.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_text])
+    render_btn_pointbox.click(_render_video, inputs=[app_state], outputs=[playback_video_pointbox])
+    render_btn_text.click(_render_video, inputs=[app_state], outputs=[playback_video_text])
 
     propagate_btn_pointbox.click(
         propagate_masks,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, propagate_status_pointbox, frame_slider_pointbox],
+        inputs=[app_state],
+        outputs=[app_state, propagate_status_pointbox, frame_slider_pointbox],
    )
 
     propagate_btn_text.click(
         propagate_masks,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, propagate_status_text, frame_slider_text],
+        inputs=[app_state],
+        outputs=[app_state, propagate_status_text, frame_slider_text],
    )
 
     reset_btn_pointbox.click(
         reset_session,
-        inputs=GLOBAL_STATE,
-        outputs=[GLOBAL_STATE, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
+        inputs=app_state,
+        outputs=[app_state, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
    )
 
    reset_btn_text.click(
        reset_session,
-        inputs=GLOBAL_STATE,
+        inputs=app_state,
        outputs=[
-            GLOBAL_STATE,
+            app_state,
            preview_text,
            frame_slider_text,
            frame_slider_text,
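
Not part of the commit itself: a minimal, self-contained sketch of the Gradio `gr.State` pattern this rename reflects. In the diff above, the handler parameter becomes `state` and the component becomes `app_state` because `gr.State` holds per-session state that is threaded explicitly through `inputs`/`outputs`, rather than a module-level global. The widget names below are illustrative only.

```python
import gradio as gr


def increment(state: int) -> tuple[int, str]:
    # State arrives via `inputs` and the updated value is returned via `outputs`;
    # each browser session gets its own copy, so nothing here is truly global.
    state += 1
    return state, f"Clicked {state} times"


with gr.Blocks() as demo:
    app_state = gr.State(0)  # per-session state container, like app.py's AppState
    out = gr.Textbox(label="Status")
    btn = gr.Button("Increment")

    # Pass the state component in `inputs` and return the new value in `outputs`,
    # mirroring how app.py threads `app_state` through its event handlers.
    btn.click(increment, inputs=[app_state], outputs=[app_state, out])

if __name__ == "__main__":
    demo.launch()
```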