SkalskiP committed
Commit 0c52132
1 Parent(s): 03b9405

Fix `SAM_CHECKPOINT` path

Files changed (1): app.py (+23 -5)
app.py CHANGED
@@ -1,14 +1,18 @@
+import os
 import torch
 
 import gradio as gr
 import numpy as np
 import supervision as sv
 
+from typing import List
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
 
+
+HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
-SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
+SAM_CHECKPOINT = os.path.join(HOME, "weights/sam_vit_h_4b8939.pth")
 SAM_MODEL_TYPE = "vit_h"
 
 MARKDOWN = """
@@ -25,12 +29,21 @@ sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DE
 mask_generator = SamAutomaticMaskGenerator(sam)
 
 
-def inference(image: np.ndarray) -> np.ndarray:
+def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
     return image
 
 
-image_input = gr.Image(label="Input", type="numpy")
-image_output = gr.Image(label="SoM Visual Prompt", type="numpy", height=512)
+image_input = gr.Image(
+    label="Input",
+    type="numpy")
+checkbox_annotation_mode = gr.CheckboxGroup(
+    choices=["Mark", "Mask", "Box"],
+    value=['Mark'],
+    label="Annotation Mode")
+image_output = gr.Image(
+    label="SoM Visual Prompt",
+    type="numpy",
+    height=512)
 run_button = gr.Button("Run")
 
 with gr.Blocks() as demo:
@@ -38,10 +51,15 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
             image_input.render()
+            with gr.Accordion(label="Detailed prompt settings (e.g., mark type)", open=False):
+                checkbox_annotation_mode.render()
         with gr.Column():
             image_output.render()
             run_button.render()
 
-    run_button.click(inference, inputs=[image_input], outputs=image_output)
+    run_button.click(
+        fn=inference,
+        inputs=[image_input, checkbox_annotation_mode],
+        outputs=image_output)
 
 demo.queue().launch(debug=False, show_error=True)
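
Note: the new path resolution assumes the HOME environment variable is set in the Space container; if it is not, os.getenv("HOME") returns None and os.path.join raises a TypeError. A minimal, hypothetical hardening of the same idea is sketched below; the expanduser fallback and the existence check are assumptions, not part of this commit.

import os

# Sketch only: fall back to the expanded "~" when HOME is unset, and fail
# early with a clear error if the SAM weights have not been downloaded yet.
HOME = os.getenv("HOME") or os.path.expanduser("~")
SAM_CHECKPOINT = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

if not os.path.isfile(SAM_CHECKPOINT):
    raise FileNotFoundError(
        f"SAM checkpoint not found at {SAM_CHECKPOINT}; "
        "download sam_vit_h_4b8939.pth into the weights/ directory first."
    )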
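
inference is still a stub that returns the input image unchanged; this commit only threads the new Annotation Mode checkbox through to the click handler. Below is a rough sketch of how the selected modes could eventually drive the SAM mask generator together with supervision annotators, reusing the mask_generator, np, sv, and List names defined in app.py. This is an assumption about the intended direction, not code from this commit, and the exact annotator API may differ between supervision versions.

def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
    # Sketch only: run automatic mask generation and overlay the requested
    # annotation types with the supervision library.
    sam_result = mask_generator.generate(image)
    detections = sv.Detections.from_sam(sam_result=sam_result)

    annotated = image.copy()
    if "Mask" in annotation_mode:
        annotated = sv.MaskAnnotator().annotate(scene=annotated, detections=detections)
    if "Box" in annotation_mode:
        annotated = sv.BoxAnnotator().annotate(scene=annotated, detections=detections)
    # "Mark" (numbered labels per region) would additionally need a label annotator.
    return annotated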