er1t0 committed
Commit
80f96e1
1 Parent(s): 8870220
Files changed (3)
  1. app.py +7 -8
  2. checkpoints/test.txt +0 -0
  3. myapp2.py +0 -204
app.py CHANGED
@@ -219,17 +219,16 @@ demo = gr.Interface(
     fn=segment_video,
     inputs=[
         gr.Video(label="Upload Video (Keep it under 10 seconds for this demo)"),
-        gr.Textbox(label="Enter text prompt for object detection")
+        gr.Textbox(label="Enter text prompt for object detection (e.g. Gymnast, Car)")
     ],
     outputs=gr.Video(label="Segmented Video"),
-    title="Text-Prompted Video Object Segmentation",
+    title="Text-Prompted Video Object Segmentation with SAMv2",
     description="""
-    This demo uses [Florence-2](https://huggingface.co/microsoft/Florence-2-large), a vision-language model, to enable text-prompted object detection for [SAM2](https://github.com/facebookresearch/segment-anything).
-    Florence-2 interprets your text prompt, allowing SAM2 to segment the described object in the video.
-
-    1. Upload a short video (< 10 sec)
-    2. Describe the object to segment
-    3. Get your segmented video!
+    This demo uses [Florence-2](https://huggingface.co/microsoft/Florence-2-large) to enable text-prompted object detection for [SAM2](https://github.com/facebookresearch/segment-anything).
+
+    1. Upload a short video (< 10 sec; you can fork this Space on a larger GPU for longer videos)
+    2. Describe the object to segment.
+    3. Get your segmented video.
     """
 )
 
 
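The new description refers to the Florence-2 → SAM2 handoff that app.py keeps: Florence-2 turns the text prompt into a bounding box, and SAM2 segments from that box. As a rough illustrative sketch (not part of this commit), the snippet below reuses the calls visible in the removed myapp2.py further down; the frame path, prompt text, and checkpoint paths are placeholder assumptions.

# Sketch: text prompt -> Florence-2 box -> SAM2 mask on a single frame.
# Assumes the sam2 package and the checkpoints referenced in myapp2.py are available.
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

model_id = "microsoft/Florence-2-large"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
florence = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
                                                torch_dtype=torch.bfloat16).eval().cuda()

image = Image.open("first_frame.jpg")  # placeholder: first frame of the uploaded video
task = "<OPEN_VOCABULARY_DETECTION>"
inputs = processor(text=task + "a gymnast", images=image, return_tensors="pt").to("cuda", torch.bfloat16)
ids = florence.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"],
                        max_new_tokens=1024, do_sample=False, num_beams=3)
decoded = processor.batch_decode(ids, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(decoded, task=task, image_size=(image.width, image.height))
box = np.array(parsed[task]["bboxes"][0])  # [x1, y1, x2, y2] for the described object

# Use the detected box as a prompt for SAM2 on the same frame.
predictor = SAM2ImagePredictor(build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device="cuda"))
predictor.set_image(image)
masks, _, _ = predictor.predict(box=box[None, :], multimask_output=False)
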
checkpoints/test.txt DELETED
File without changes
myapp2.py DELETED
@@ -1,204 +0,0 @@
-import os
-import torch
-import numpy as np
-import gradio as gr
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM
-from sam2.build_sam import build_sam2_video_predictor, build_sam2
-from sam2.sam2_image_predictor import SAM2ImagePredictor
-import cv2
-import traceback
-import matplotlib.pyplot as plt
-
-# CUDA optimizations
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-if torch.cuda.get_device_properties(0).major >= 8:
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-# Initialize models
-sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
-model_cfg = "sam2_hiera_l.yaml"
-
-video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
-sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
-image_predictor = SAM2ImagePredictor(sam2_model)
-
-model_id = 'microsoft/Florence-2-large'
-florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16).eval().cuda()
-florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-
-def apply_color_mask(frame, mask, obj_id):
-    cmap = plt.get_cmap("tab10")
-    color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
-
-    # Ensure mask has the correct shape
-    if mask.ndim == 4:
-        mask = mask.squeeze() # Remove singleton dimensions
-    if mask.ndim == 3 and mask.shape[0] == 1:
-        mask = mask[0] # Take the first channel if it's a single-channel 3D array
-
-    # Reshape mask to match frame dimensions
-    mask = cv2.resize(mask.astype(np.float32), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
-
-    # Expand dimensions of mask and color for broadcasting
-    mask = np.expand_dims(mask, axis=2)
-    color = color.reshape(1, 1, 3)
-
-    colored_mask = mask * color
-    return frame * (1 - mask) + colored_mask * 255
-
-def run_florence(image, text_input):
-    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-        task_prompt = '<OPEN_VOCABULARY_DETECTION>'
-        prompt = task_prompt + text_input
-        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
-        generated_ids = florence_model.generate(
-            input_ids=inputs["input_ids"].cuda(),
-            pixel_values=inputs["pixel_values"].cuda(),
-            max_new_tokens=1024,
-            early_stopping=False,
-            do_sample=False,
-            num_beams=3,
-        )
-        generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-        parsed_answer = florence_processor.post_process_generation(
-            generated_text,
-            task=task_prompt,
-            image_size=(image.width, image.height)
-        )
-        return parsed_answer[task_prompt]['bboxes'][0]
-
-def remove_directory_contents(directory):
-    for root, dirs, files in os.walk(directory, topdown=False):
-        for name in files:
-            os.remove(os.path.join(root, name))
-        for name in dirs:
-            os.rmdir(os.path.join(root, name))
-
-def process_video(video_path, prompt, chunk_size=30):
-    try:
-        video = cv2.VideoCapture(video_path)
-        if not video.isOpened():
-            raise ValueError("Unable to open video file")
-
-        fps = video.get(cv2.CAP_PROP_FPS)
-        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-
-        # Process video in chunks
-        all_segmented_frames = []
-        for chunk_start in range(0, frame_count, chunk_size):
-            chunk_end = min(chunk_start + chunk_size, frame_count)
-
-            frames = []
-            video.set(cv2.CAP_PROP_POS_FRAMES, chunk_start)
-            for _ in range(chunk_end - chunk_start):
-                ret, frame = video.read()
-                if not ret:
-                    break
-                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-
-            if not frames:
-                print(f"No frames extracted for chunk starting at {chunk_start}")
-                continue
-
-            # Florence detection on first frame of the chunk
-            first_frame = Image.fromarray(frames[0])
-            mask_box = run_florence(first_frame, prompt)
-            print("Original mask box:", mask_box)
-
-            # Convert mask_box to numpy array and ensure it's in the correct format
-            mask_box = np.array(mask_box)
-            print("Reshaped mask box:", mask_box)
-
-            # SAM2 segmentation on first frame
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                image_predictor.set_image(first_frame)
-                masks, _, _ = image_predictor.predict(
-                    point_coords=None,
-                    point_labels=None,
-                    box=mask_box[None, :],
-                    multimask_output=False,
-                )
-            print("masks.shape",masks.shape)
-
-            mask = masks.squeeze().astype(bool)
-            print("Mask shape:", mask.shape)
-            print("Frame shape:", frames[0].shape)
-
-            # SAM2 video propagation
-            temp_dir = f"temp_frames_{chunk_start}"
-            os.makedirs(temp_dir, exist_ok=True)
-            for i, frame in enumerate(frames):
-                cv2.imwrite(os.path.join(temp_dir, f"{i:04d}.jpg"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                inference_state = video_predictor.init_state(video_path=temp_dir)
-                _, _, _ = video_predictor.add_new_mask(
-                    inference_state=inference_state,
-                    frame_idx=0,
-                    obj_id=1,
-                    mask=mask
-                )
-
-                video_segments = {}
-                for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
-                    video_segments[out_frame_idx] = {
-                        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                        for i, out_obj_id in enumerate(out_obj_ids)
-                    }
-
-            print('segmenting for main vid done')
-
-            # Apply segmentation masks to frames
-            for i, frame in enumerate(frames):
-                if i in video_segments:
-                    for out_obj_id, mask in video_segments[i].items():
-                        frame = apply_color_mask(frame, mask, out_obj_id)
-                    all_segmented_frames.append(frame.astype(np.uint8))
-                else:
-                    all_segmented_frames.append(frame)
-
-            # Clean up temporary files
-            remove_directory_contents(temp_dir)
-            os.rmdir(temp_dir)
-
-        video.release()
-
-        if not all_segmented_frames:
-            raise ValueError("No frames were processed successfully")
-
-        # Create video from segmented frames
-        output_path = "segmented_video.mp4"
-        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
-                              (all_segmented_frames[0].shape[1], all_segmented_frames[0].shape[0]))
-        for frame in all_segmented_frames:
-            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-        out.release()
-
-        return output_path
-
-    except Exception as e:
-        print(f"Error in process_video: {str(e)}")
-        print(traceback.format_exc()) # This will print the full stack trace
-        return None
-
-def segment_video(video_file, prompt, chunk_size):
-    if video_file is None:
-        return None
-    output_video = process_video(video_file, prompt, int(chunk_size))
-    return output_video
-
-demo = gr.Interface(
-    fn=segment_video,
-    inputs=[
-        gr.Video(label="Upload Video"),
-        gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
-        gr.Slider(minimum=10, maximum=100, step=10, value=30, label="Chunk Size (frames)")
-    ],
-    outputs=gr.Video(label="Segmented Video"),
-    title="Video Object Segmentation with Florence and SAM2",
-    description="Upload a video and provide a text prompt to segment a specific object throughout the video."
-)
-
-demo.launch()