SkalskiP committed on
Commit b577d3a · 1 Parent(s): 725f958

SAM box inference is working

Files changed (3)
  1. .gitignore +2 -1
  2. app.py +107 -11
  3. requirements.txt +2 -2
.gitignore CHANGED
@@ -1 +1,2 @@
-venv/
+venv/
+.idea/
app.py CHANGED
@@ -1,29 +1,125 @@
-import torch
+import time
 import gradio as gr
+import numpy as np
+import supervision as sv
+from PIL import Image
+import torch
+from transformers import SamModel, SamProcessor
+from typing import Tuple
+
 
 MARKDOWN = """
 # EfficientSAM vs. SAM
 """
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
+SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+MASK_ANNOTATOR = sv.MaskAnnotator(
+    color=sv.Color.red(),
+    color_lookup=sv.ColorLookup.INDEX)
 
 
-def inference(image):
+def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
+    bgr_image = image[:, :, ::-1]
+    annotated_bgr_image = MASK_ANNOTATOR.annotate(
+        scene=bgr_image, detections=detections)
+    return annotated_bgr_image[:, :, ::-1]
+
+
+def efficient_sam_inference(
+    image: np.ndarray,
+    x_min: int,
+    y_min: int,
+    x_max: int,
+    y_max: int
+) -> np.ndarray:
+    time.sleep(0.2)
     return image
 
 
+def sam_inference(
+    image: np.ndarray,
+    x_min: int,
+    y_min: int,
+    x_max: int,
+    y_max: int
+) -> np.ndarray:
+    input_boxes = [[[x_min, y_min, x_max, y_max]]]
+    inputs = SAM_PROCESSOR(
+        Image.fromarray(image),
+        input_boxes=[input_boxes],
+        return_tensors="pt"
+    ).to(DEVICE)
+
+    with torch.no_grad():
+        outputs = SAM_MODEL(**inputs)
+
+    mask = SAM_PROCESSOR.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
+    )[0][0][0].numpy()
+    mask = mask[np.newaxis, ...]
+    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
+    return annotate_image(image=image, detections=detections)
+
+
+def inference(
+    image: np.ndarray,
+    x_min: int,
+    y_min: int,
+    x_max: int,
+    y_max: int
+) -> Tuple[np.ndarray, np.ndarray]:
+    return (
+        efficient_sam_inference(image, x_min, y_min, x_max, y_max),
+        sam_inference(image, x_min, y_min, x_max, y_max)
+    )
+
+
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
-    with gr.Row():
-        input_image = gr.Image()
-        output_image = gr.Image()
-    with gr.Row():
-        submit_button = gr.Button("Submit")
+    with gr.Tab(label="Box prompt"):
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image()
+                with gr.Accordion(label="Box", open=False):
+                    with gr.Row():
+                        x_min_number = gr.Number(label="x_min")
+                        y_min_number = gr.Number(label="y_min")
+                        x_max_number = gr.Number(label="x_max")
+                        y_max_number = gr.Number(label="y_max")
+            efficient_sam_output_image = gr.Image()
+            sam_output_image = gr.Image()
+        with gr.Row():
+            submit_button = gr.Button("Submit")
+
+    gr.Examples(
+        fn=inference,
+        examples=[
+            [
+                'https://media.roboflow.com/notebooks/examples/dog.jpeg',
+                69,
+                247,
+                624,
+                930
+            ]
+        ],
+        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
+        outputs=[efficient_sam_output_image, sam_output_image],
+    )
+
+    submit_button.click(
+        efficient_sam_inference,
+        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
+        outputs=efficient_sam_output_image
+    )
 
     submit_button.click(
-        inference,
-        inputs=[input_image],
-        outputs=output_image
+        sam_inference,
+        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
+        outputs=sam_output_image
     )
 
 demo.launch(debug=False, show_error=True)
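
For reference, the box-prompt flow this commit wires up can be reproduced outside Gradio. Below is a minimal standalone sketch mirroring sam_inference: the checkpoint, box nesting, and post-processing calls are taken from app.py above, while the local file name dog.jpeg and the final print are illustrative assumptions.

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Placeholder path: a local copy of the Roboflow example image used in
# gr.Examples above; the box is its (x_min, y_min, x_max, y_max) prompt.
image = Image.open("dog.jpeg").convert("RGB")
input_boxes = [[[69, 247, 624, 930]]]

# Same nesting as app.py: one batch containing one box prompt.
inputs = processor(image, input_boxes=[input_boxes], return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# post_process_masks upscales the low-resolution mask logits back to the
# original image size; [0][0][0] picks image 0, box 0, first mask proposal.
mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()

print(mask.shape, mask.dtype)  # (H, W) boolean mask, ready for sv.Detections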
requirements.txt CHANGED
@@ -2,7 +2,7 @@
 torch
 torchvision
 
+pillow
 gradio
 transformers
-supervision
-gradio-imageslider
+supervision
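
The requirements now match the new imports in app.py: pillow backs the PIL usage and supervision provides the mask annotator, while numpy arrives transitively via supervision. A typical local run, assuming a cloned checkout and an activated virtual environment:

pip install -r requirements.txt
python app.py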