fffiloni commited on
Commit
935805a
1 Parent(s): 05605d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -12,7 +12,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
  def preprocess_image(image):
13
  return image, gr.State([]), gr.State([]), image
14
 
15
- def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
16
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
17
 
18
  tracking_points.value.append(evt.index)
@@ -24,16 +24,25 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_pa
24
  trackings_input_label.value.append(0)
25
  print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
26
 
 
27
  transparent_background = Image.open(first_frame_path).convert('RGBA')
28
  w, h = transparent_background.size
29
- transparent_layer = np.zeros((h, w, 4))
 
 
 
 
 
 
 
30
  for index, track in enumerate(tracking_points.value):
31
  if trackings_input_label.value[index] == 1:
32
- cv2.circle(transparent_layer, track, 20, (0, 255, 0, 255), -1)
33
  else:
34
- cv2.circle(transparent_layer, track, 20, (255, 0, 0, 255), -1)
35
 
36
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
 
37
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
38
 
39
  return tracking_points, trackings_input_label, selected_point_map
@@ -170,19 +179,17 @@ with gr.Blocks() as demo:
170
  gr.Markdown("# SAM2 Image Predictor")
171
  with gr.Row():
172
  with gr.Column():
173
- input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
 
 
 
 
 
174
  with gr.Row():
175
  point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
176
  clear_points_btn = gr.Button("Clear Points")
177
-
178
- with gr.Column():
179
- checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
180
- points_map = gr.Image(
181
- label="points map",
182
- type="filepath",
183
- interactive=True
184
- )
185
- submit_btn = gr.Button("Submit")
186
  with gr.Column():
187
  output_result = gr.Image()
188
  output_result_mask = gr.Image()
 
12
  def preprocess_image(image):
13
  return image, gr.State([]), gr.State([]), image
14
 
15
+ def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt):
16
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
17
 
18
  tracking_points.value.append(evt.index)
 
24
  trackings_input_label.value.append(0)
25
  print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
26
 
27
+ # Open the image and get its dimensions
28
  transparent_background = Image.open(first_frame_path).convert('RGBA')
29
  w, h = transparent_background.size
30
+
31
+ # Define the circle radius as a fraction of the smaller dimension
32
+ fraction = 0.02 # You can adjust this value as needed
33
+ radius = int(fraction * min(w, h))
34
+
35
+ # Create a transparent layer to draw on
36
+ transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
37
+
38
  for index, track in enumerate(tracking_points.value):
39
  if trackings_input_label.value[index] == 1:
40
+ cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
41
  else:
42
+ cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
43
 
44
+ # Convert the transparent layer back to an image
45
+ transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
46
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
47
 
48
  return tracking_points, trackings_input_label, selected_point_map
 
179
  gr.Markdown("# SAM2 Image Predictor")
180
  with gr.Row():
181
  with gr.Column():
182
+ input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
183
+ points_map = gr.Image(
184
+ label="points map",
185
+ type="filepath",
186
+ interactive=True
187
+ )
188
  with gr.Row():
189
  point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
190
  clear_points_btn = gr.Button("Clear Points")
191
+ checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
192
+ submit_btn = gr.Button("Submit")
 
 
 
 
 
 
 
193
  with gr.Column():
194
  output_result = gr.Image()
195
  output_result_mask = gr.Image()