SkalskiP committed
Commit
017684f
1 Parent(s): 2740889

Deleted 'gpt4v.py' and moved its functionalities to 'utils.py' and 'app.py'.

Files changed (4)
  1. Dockerfile +1 -2
  2. app.py +31 -23
  3. gpt4v.py +0 -81
  4. utils.py +10 -67
Dockerfile CHANGED
```diff
@@ -32,7 +32,7 @@ RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download
 
 # Install dependencies
 RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc4 \
-    pillow requests
+    pillow requests setofmark==0.1.0rc3
 
 # Install SAM and Detectron2
 RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
@@ -44,7 +44,6 @@ RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles
 
 COPY app.py .
 COPY utils.py .
-COPY gpt4v.py .
 COPY sam_utils.py .
 
 RUN find $HOME/app
```
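The only dependency change is the new `setofmark==0.1.0rc3` pin, which provides the `som` module this commit imports. A minimal smoke-test sketch (the three function names come straight from the app.py and utils.py diffs below; everything else is an assumption):

```python
# Hedged smoke test: confirm the som API surface this commit relies on.
import som

for name in ("prompt_image", "extract_relevant_masks", "mask_non_max_suppression"):
    assert hasattr(som, name), f"som.{name} is missing"
print("setofmark exposes the expected functions")
```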
app.py CHANGED
```diff
@@ -4,13 +4,13 @@ from typing import List, Dict, Tuple, Any, Optional
 import cv2
 import gradio as gr
 import numpy as np
+import som
 import supervision as sv
 import torch
 from segment_anything import sam_model_registry
 
-from gpt4v import prompt_image
 from sam_utils import sam_interactive_inference, sam_inference
-from utils import postprocess_masks, Visualizer, extract_numbers_in_brackets
+from utils import postprocess_masks, Visualizer
 
 HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@@ -21,17 +21,21 @@ SAM_MODEL_TYPE = "vit_h"
 
 ANNOTATED_IMAGE_KEY = "annotated_image"
 DETECTIONS_KEY = "detections"
-
 MARKDOWN = """
-[![arXiv](https://img.shields.io/badge/arXiv-1703.06870v3-b31b1b.svg)](https://arxiv.org/pdf/2310.11441.pdf)
-
-<h1 style='text-align: center'>
-    <img
-        src='https://som-gpt4v.github.io/website/img/som_logo.png'
-        style='height:50px; display:inline-block'
-    />
-    Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
-</h1>
+<div align='center'>
+    <h1>
+        <img
+            src='https://som-gpt4v.github.io/website/img/som_logo.png'
+            style='height:50px; display:inline-block'
+        />
+        Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
+    </h1>
+    <br>
+    [<a href="https://arxiv.org/abs/2109.07529"> arXiv paper </a>]
+    [<a href="https://som-gpt4v.github.io"> project page </a>]
+    [<a href="https://github.com/roboflow/set-of-mark"> python package </a>]
+    [<a href="https://github.com/microsoft/SoM"> code </a>]
+</div>
 
 ## 🚧 Roadmap
 
@@ -90,7 +94,7 @@ def prompt(
         return "⚠️ Please set your OpenAI API key first"
     if state is None or ANNOTATED_IMAGE_KEY not in state:
         return "⚠️ Please generate SoM visual prompt first"
-    return prompt_image(
+    return som.prompt_image(
         api_key=api_key,
         image=cv2.cvtColor(state[ANNOTATED_IMAGE_KEY], cv2.COLOR_BGR2RGB),
         prompt=message
@@ -114,15 +118,17 @@ def highlight(
     if len(history) == 0:
         return None
 
-    response = history[-1][-1]
-    detections_ids = extract_numbers_in_brackets(text=response)
-    highlighted_detections = [
-        (detections.mask[detection_id], str(detection_id))
-        for detection_id
-        in detections_ids
+    text = history[-1][-1]
+    relevant_masks = som.extract_relevant_masks(
+        text=text,
+        detections=detections
+    )
+    relevant_masks = [
+        (mask, mark)
+        for mark, mask
+        in relevant_masks.items()
     ]
-
-    return annotated_image, highlighted_detections
+    return annotated_image, relevant_masks
 
 
 image_input = gr.Image(
@@ -131,7 +137,8 @@ image_input = gr.Image(
     tool="sketch",
     interactive=True,
     brush_radius=20.0,
-    brush_color="#FFFFFF"
+    brush_color="#FFFFFF",
+    height=512
 )
 checkbox_annotation_mode = gr.CheckboxGroup(
     choices=["Mark", "Polygon", "Mask", "Box"],
@@ -147,7 +154,8 @@ image_output = gr.AnnotatedImage(
     color_map={
         str(i): sv.ColorPalette.default().by_idx(i).as_hex()
         for i in range(64)
-    }
+    },
+    height=512
 )
 openai_api_key = gr.Textbox(
     show_label=False,
```
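The `highlight` rewrite delegates mark parsing to the package: `som.extract_relevant_masks` appears to return a dict mapping each mark label mentioned in the model's answer to its mask (the diff iterates `.items()` as `mark, mask`), which is then reshaped into the `(mask, label)` tuples that `gr.AnnotatedImage` accepts. A standalone sketch of that flow, with the masks, `detections`, and the answer text as made-up stand-ins for the values held in the Gradio state:

```python
import numpy as np
import som
import supervision as sv

# Stand-ins for state[DETECTIONS_KEY] / state[ANNOTATED_IMAGE_KEY].
masks = np.zeros((2, 8, 8), dtype=bool)
masks[0, :4], masks[1, 4:] = True, True
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks)
annotated_image = np.zeros((8, 8, 3), dtype=np.uint8)

text = "The object at [1] is a dog."  # a GPT-4V-style answer citing one mark
relevant = som.extract_relevant_masks(text=text, detections=detections)
highlights = [(mask, mark) for mark, mask in relevant.items()]
# -> e.g. [(mask_for_mark_1, "1")], ready to pass to gr.AnnotatedImage
```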
gpt4v.py DELETED
```diff
@@ -1,81 +0,0 @@
-import cv2
-import base64
-import requests
-
-import numpy as np
-
-
-META_PROMPT = '''
-For any labels or markings on an image that you reference in your response, please
-enclose them in square brackets ([]) and list them explicitly. Do not use ranges; for
-example, instead of '1 - 4', list as '[1], [2], [3], [4]'. These labels could be
-numbers or letters and typically correspond to specific segments or parts of the image.
-'''
-API_URL = "https://api.openai.com/v1/chat/completions"
-
-
-def encode_image_to_base64(image: np.ndarray) -> str:
-    """
-    Encodes an image into a base64-encoded string in JPEG format.
-
-    Parameters:
-        image (np.ndarray): The image to be encoded. This should be a numpy array as
-            typically used in OpenCV.
-
-    Returns:
-        str: A base64-encoded string representing the image in JPEG format.
-    """
-    success, buffer = cv2.imencode('.jpg', image)
-    if not success:
-        raise ValueError("Could not encode image to JPEG format.")
-
-    encoded_image = base64.b64encode(buffer).decode('utf-8')
-    return encoded_image
-
-
-def compose_headers(api_key: str) -> dict:
-    return {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}"
-    }
-
-
-def compose_payload(image: np.ndarray, prompt: str) -> dict:
-    base64_image = encode_image_to_base64(image)
-    return {
-        "model": "gpt-4-vision-preview",
-        "messages": [
-            {
-                "role": "system",
-                "content": [
-                    META_PROMPT
-                ]
-            },
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": prompt
-                    },
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{base64_image}"
-                        }
-                    }
-                ]
-            }
-        ],
-        "max_tokens": 800
-    }
-
-
-def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str:
-    headers = compose_headers(api_key=api_key)
-    payload = compose_payload(image=image, prompt=prompt)
-    response = requests.post(url=API_URL, headers=headers, json=payload).json()
-
-    if 'error' in response:
-        raise ValueError(response['error']['message'])
-    return response['choices'][0]['message']['content']
```
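Before this commit, the app imported `prompt_image` from this module; after it, app.py calls `som.prompt_image` with identical keyword arguments, so the request flow above is presumably what the package now provides. A before/after usage sketch (the image path, prompt, and API key are placeholders):

```python
import cv2
import som

image_bgr = cv2.imread("annotated.jpg")  # placeholder path; OpenCV loads BGR
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Old: from gpt4v import prompt_image; prompt_image(api_key=..., image=..., prompt=...)
# New (this commit), same keyword arguments:
answer = som.prompt_image(
    api_key="sk-...",  # your OpenAI API key
    image=image_rgb,
    prompt="Which marks belong to the dog?",
)
print(answer)
```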
utils.py CHANGED
```diff
@@ -1,7 +1,5 @@
-import re
-from typing import List
-
 import cv2
+import som
 
 import numpy as np
 import supervision as sv
@@ -13,7 +11,7 @@ class Visualizer:
         self,
         line_thickness: int = 2,
         mask_opacity: float = 0.1,
-        text_scale: float = 0.5
+        text_scale: float = 0.6
     ) -> None:
         self.box_annotator = sv.BoundingBoxAnnotator(
             color_lookup=sv.ColorLookup.INDEX,
@@ -25,6 +23,8 @@ class Visualizer:
             color_lookup=sv.ColorLookup.INDEX,
             thickness=line_thickness)
         self.label_annotator = sv.LabelAnnotator(
+            color=sv.Color.black(),
+            text_color=sv.Color.white(),
             color_lookup=sv.ColorLookup.INDEX,
             text_position=sv.Position.CENTER_OF_MASS,
             text_scale=text_scale)
@@ -85,7 +85,11 @@ def refine_mask(
         relative_area = area / total_area
         if relative_area < area_threshold:
             cv2.drawContours(
-                mask, [contour], -1, (0 if mode == 'islands' else 255), -1
+                image=mask,
+                contours=[contour],
+                contourIdx=-1,
+                color=(0 if mode == 'islands' else 255),
+                thickness=-1
             )
 
     return np.where(mask > 0, 1, 0).astype(bool)
@@ -116,52 +120,6 @@ def filter_masks_by_relative_area(
     return masks[min_area_filter & max_area_filter]
 
 
-def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
-    """
-    Computes the Intersection over Union (IoU) of two masks.
-
-    Parameters:
-        mask1, mask2 (np.ndarray): Two mask arrays.
-
-    Returns:
-        float: The IoU of the two masks.
-    """
-    intersection = np.logical_and(mask1, mask2).sum()
-    union = np.logical_or(mask1, mask2).sum()
-    return intersection / union if union != 0 else 0
-
-
-def filter_highly_overlapping_masks(
-    masks: np.ndarray,
-    iou_threshold: float
-) -> np.ndarray:
-    """
-    Removes masks with high overlap from a set of masks.
-
-    Parameters:
-        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is the
-            number of masks, and H and W are the height and width of the masks.
-        iou_threshold (float): The IoU threshold above which masks will be considered as
-            overlapping.
-
-    Returns:
-        np.ndarray: A 3D numpy array of masks with highly overlapping masks removed.
-    """
-    num_masks = masks.shape[0]
-    keep_mask = np.ones(num_masks, dtype=bool)
-
-    for i in range(num_masks):
-        for j in range(i + 1, num_masks):
-            if not keep_mask[i] or not keep_mask[j]:
-                continue
-
-            iou = compute_iou(masks[i, :, :], masks[j, :, :])
-            if iou > iou_threshold:
-                keep_mask[j] = False
-
-    return masks[keep_mask]
-
-
 def postprocess_masks(
     detections: sv.Detections,
     area_threshold: float = 0.01,
@@ -200,7 +158,7 @@ def postprocess_masks(
         masks=masks,
         min_relative_area=min_relative_area,
         max_relative_area=max_relative_area)
-    masks = filter_highly_overlapping_masks(
+    masks = som.mask_non_max_suppression(
        masks=masks,
        iou_threshold=iou_threshold)
 
@@ -208,18 +166,3 @@ def postprocess_masks(
         xyxy=sv.mask_to_xyxy(masks),
         mask=masks
     )
-
-
-def extract_numbers_in_brackets(text: str) -> List[int]:
-    """
-    Extracts all numbers enclosed in square brackets from a given string.
-
-    Args:
-        text (str): The string to be searched.
-
-    Returns:
-        List[int]: A list of integers found within square brackets.
-    """
-    pattern = r'\[(\d+)\]'
-    numbers = [int(num) for num in re.findall(pattern, text)]
-    return numbers
```
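The hand-rolled `compute_iou` / `filter_highly_overlapping_masks` pair is replaced by `som.mask_non_max_suppression`, called with the same `masks` and `iou_threshold` arguments. Assuming it keeps the semantics of the removed code (drop any later mask whose IoU with an earlier kept mask exceeds the threshold), a toy check:

```python
import numpy as np
import som

# Three toy masks; masks[1] overlaps masks[0] with IoU = 8/9 ≈ 0.89.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, :2, :] = True
masks[1, :2, :] = True
masks[1, 2, 0] = True
masks[2, 2:, :] = True

kept = som.mask_non_max_suppression(masks=masks, iou_threshold=0.8)
print(kept.shape)  # expected (2, 4, 4): the near-duplicate mask is dropped
```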