SkalskiP committed on
Commit 7e50af9
1 Parent(s): 54c9770

Updated the 'sam_utils.py' and 'app.py' modules to implement automated mask generation, result highlighting, and mark generation.

Files changed (3)
  1. Dockerfile +1 -1
  2. app.py +69 -24
  3. sam_utils.py +10 -1
Dockerfile CHANGED
@@ -31,7 +31,7 @@ WORKDIR $HOME/app
 RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
 
 # Install dependencies
-RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3 \
+RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc4 \
     pillow requests
 
 # Install SAM and Detectron2
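Note: the only Dockerfile change is bumping the supervision pin from 0.17.0rc3 to 0.17.0rc4. A quick sanity check inside the built image could confirm the pinned build exposes the APIs this commit starts using; a minimal sketch (the attributes checked are inferred from the calls in app.py and sam_utils.py):

import supervision as sv

# confirm the pinned release and the APIs used elsewhere in this commit
print(sv.__version__)                                # expected: 0.17.0rc4
assert hasattr(sv.Detections, "from_sam")            # used by sam_utils.sam_inference
assert hasattr(sv.ColorPalette.default(), "by_idx")  # used for the AnnotatedImage color_map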
app.py CHANGED
@@ -1,16 +1,16 @@
 import os
-from typing import List, Dict, Tuple, Any
+from typing import List, Dict, Tuple, Any, Optional
 
 import cv2
 import gradio as gr
 import numpy as np
 import supervision as sv
 import torch
-from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+from segment_anything import sam_model_registry
 
 from gpt4v import prompt_image
-from utils import postprocess_masks, Visualizer
-from sam_utils import sam_interactive_inference
+from utils import postprocess_masks, Visualizer, extract_numbers_in_brackets
+from sam_utils import sam_interactive_inference, sam_inference
 
 HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@@ -19,6 +19,9 @@ SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
 # SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
 SAM_MODEL_TYPE = "vit_h"
 
+ANNOTATED_IMAGE_KEY = "annotated_image"
+DETECTIONS_KEY = "detections"
+
 MARKDOWN = """
 [![arXiv](https://img.shields.io/badge/arXiv-1703.06870v3-b31b1b.svg)](https://arxiv.org/pdf/2310.11441.pdf)
 
@@ -34,7 +37,6 @@ MARKDOWN = """
 
 - [ ] Support for alphabetic labels
 - [ ] Support for Semantic-SAM (multi-level)
-- [ ] Support for result highlighting
 - [ ] Support for mask filtering based on granularity
 """
 
@@ -45,7 +47,7 @@ def inference(
     image_and_mask: Dict[str, np.ndarray],
     annotation_mode: List[str],
     mask_alpha: float
-) -> Tuple[Tuple[np.ndarray, List[Any]], sv.Detections]:
+) -> Tuple[Tuple[np.ndarray, List[Tuple[np.ndarray, str]]], Dict[str, Any]]:
     image = image_and_mask['image']
     mask = cv2.cvtColor(image_and_mask['mask'], cv2.COLOR_RGB2GRAY)
     is_interactive = not np.all(mask == 0)
@@ -56,9 +58,10 @@ def inference(
             mask=mask,
             model=SAM)
     else:
-        mask_generator = SamAutomaticMaskGenerator(SAM)
-        result = mask_generator.generate(image=image)
-        detections = sv.Detections.from_sam(result)
+        detections = sam_inference(
+            image=image,
+            model=SAM
+        )
     detections = postprocess_masks(
         detections=detections)
     bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
@@ -69,23 +72,54 @@
         with_mask="Mask" in annotation_mode,
         with_polygon="Polygon" in annotation_mode,
         with_label="Mark" in annotation_mode)
-    return (cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), []), detections
-
-
-def prompt(message, history, image: np.ndarray, api_key: str) -> str:
+    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
+    state = {
+        ANNOTATED_IMAGE_KEY: annotated_image,
+        DETECTIONS_KEY: detections
+    }
+    return (annotated_image, []), state
+
+
+def prompt(
+    message: str,
+    history: List[List[str]],
+    state: Dict[str, Any],
+    api_key: Optional[str]
+) -> str:
     if api_key == "":
         return "⚠️ Please set your OpenAI API key first"
-    if image is None:
+    if state is None or ANNOTATED_IMAGE_KEY not in state:
         return "⚠️ Please generate SoM visual prompt first"
     return prompt_image(
         api_key=api_key,
-        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+        image=cv2.cvtColor(state[ANNOTATED_IMAGE_KEY], cv2.COLOR_BGR2RGB),
         prompt=message
     )
 
 
 def on_image_input_clear():
-    return None, None
+    return None, {}
+
+
+def highlight(
+    state: Dict[str, Any],
+    history: List[List[str]]
+) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, str]]]]:
+    if DETECTIONS_KEY not in state or ANNOTATED_IMAGE_KEY not in state:
+        return None
+
+    detections: sv.Detections = state[DETECTIONS_KEY]
+    annotated_image: np.ndarray = state[ANNOTATED_IMAGE_KEY]
+
+    response = history[-1][-1]
+    detections_ids = extract_numbers_in_brackets(text=response)
+    highlighted_detections = [
+        (detections.mask[detection_id], str(detection_id))
+        for detection_id
+        in detections_ids
+    ]
+
+    return annotated_image, highlighted_detections
 
 
 image_input = gr.Image(
@@ -106,7 +140,12 @@ slider_mask_alpha = gr.Slider(
     value=0.05,
     label="Mask Alpha")
 image_output = gr.AnnotatedImage(
-    label="SoM Visual Prompt")
+    label="SoM Visual Prompt",
+    color_map={
+        str(i): sv.ColorPalette.default().by_idx(i).as_hex()
+        for i in range(64)
+    }
+)
 openai_api_key = gr.Textbox(
     show_label=False,
     placeholder="Before you start chatting, set your OpenAI API key here",
@@ -115,11 +154,12 @@ openai_api_key = gr.Textbox(
 chatbot = gr.Chatbot(
     label="GPT-4V + SoM",
     height=256)
-run_button = gr.Button("Run")
+generate_button = gr.Button("Generate Marks")
+highlight_button = gr.Button("Highlight Marks")
 
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
-    detections_state = gr.State()
+    inference_state = gr.State({})
     with gr.Row():
         with gr.Column():
             image_input.render()
@@ -132,22 +172,27 @@ with gr.Blocks() as demo:
             slider_mask_alpha.render()
         with gr.Column():
            image_output.render()
-            run_button.render()
+            generate_button.render()
+            highlight_button.render()
     with gr.Row():
         openai_api_key.render()
     with gr.Row():
         gr.ChatInterface(
             chatbot=chatbot,
             fn=prompt,
-            additional_inputs=[image_output, openai_api_key])
+            additional_inputs=[inference_state, openai_api_key])
 
-    run_button.click(
+    generate_button.click(
         fn=inference,
         inputs=[image_input, checkbox_annotation_mode, slider_mask_alpha],
-        outputs=[image_output, detections_state])
+        outputs=[image_output, inference_state])
     image_input.clear(
         fn=on_image_input_clear,
-        outputs=[image_output, detections_state]
+        outputs=[image_output, inference_state]
     )
+    highlight_button.click(
+        fn=highlight,
+        inputs=[inference_state, chatbot],
+        outputs=[image_output])
 
 demo.queue().launch(debug=False, show_error=True)
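Note: the highlight flow depends on utils.extract_numbers_in_brackets, which is imported above but defined outside this diff. Judging from its use on the last chatbot response, it presumably pulls mark IDs such as [3] out of the GPT-4V answer so highlight can hand the matching masks to gr.AnnotatedImage. A hypothetical sketch of such a helper, assuming bracketed integers (the real utils.py implementation may differ):

import re
from typing import List


def extract_numbers_in_brackets(text: str) -> List[int]:
    # hypothetical sketch: collect every integer wrapped in square brackets,
    # e.g. "marks [1] and [4]" -> [1, 4], deduplicated in first-seen order
    matches = re.findall(r"\[(\d+)\]", text)
    return list(dict.fromkeys(int(match) for match in matches))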
sam_utils.py CHANGED
@@ -2,7 +2,16 @@ import numpy as np
 import supervision as sv
 
 from segment_anything.modeling.sam import Sam
-from segment_anything import SamPredictor
+from segment_anything import SamPredictor, SamAutomaticMaskGenerator
+
+
+def sam_inference(
+    image: np.ndarray,
+    model: Sam
+) -> sv.Detections:
+    mask_generator = SamAutomaticMaskGenerator(model)
+    result = mask_generator.generate(image=image)
+    return sv.Detections.from_sam(result)
 
 
 def sam_interactive_inference(
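The new sam_inference helper wraps SamAutomaticMaskGenerator so app.py can fall back to whole-image segmentation whenever the user draws no interactive mask. A usage sketch under assumed paths ("example.jpg" is a placeholder; the checkpoint path mirrors the commented-out SAM_CHECKPOINT in app.py):

import cv2
import torch
from segment_anything import sam_model_registry
from sam_utils import sam_inference

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry["vit_h"](checkpoint="weights/sam_vit_h_4b8939.pth").to(device)

# SamAutomaticMaskGenerator expects an RGB uint8 image
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
detections = sam_inference(image=image, model=sam)
print(len(detections))  # number of masks SAM proposed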