mart9992 commited on
Commit
9856e13
1 Parent(s): b793f0c
GroundingDINO/groundingdino/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
GroundingDINO/groundingdino/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (158 Bytes). View file
 
GroundingDINO/groundingdino/datasets/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
__pycache__/grounded_sam_demo.cpython-310.pyc ADDED
Binary file (3.58 kB). View file
 
__pycache__/handler.cpython-310.pyc ADDED
Binary file (1.88 kB). View file
 
__pycache__/test.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
grounded_sam_demo.py CHANGED
@@ -1,4 +1,5 @@
1
- import argparse
 
2
  import os
3
  import copy
4
 
@@ -16,8 +17,8 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
16
 
17
  # segment anything
18
  from segment_anything import (
19
- sam_model_registry,
20
- sam_hq_model_registry,
21
  SamPredictor
22
  )
23
  import cv2
@@ -25,27 +26,13 @@ import numpy as np
25
  import matplotlib.pyplot as plt
26
 
27
 
28
- def load_image(image_path):
29
- # load image
30
- image_pil = Image.open(image_path).convert("RGB") # load image
31
-
32
- transform = T.Compose(
33
- [
34
- T.RandomResize([800], max_size=1333),
35
- T.ToTensor(),
36
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
37
- ]
38
- )
39
- image, _ = transform(image_pil, None) # 3, h, w
40
- return image_pil, image
41
-
42
-
43
  def load_model(model_config_path, model_checkpoint_path, device):
44
  args = SLConfig.fromfile(model_config_path)
45
  args.device = device
46
  model = build_model(args)
47
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
48
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
 
49
  print(load_res)
50
  _ = model.eval()
51
  return model
@@ -72,136 +59,38 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
72
  boxes_filt = boxes_filt[filt_mask] # num_filt, 4
73
  logits_filt.shape[0]
74
 
75
- # get phrase
76
- tokenlizer = model.tokenizer
77
- tokenized = tokenlizer(caption)
78
- # build pred
79
- pred_phrases = []
80
- for logit, box in zip(logits_filt, boxes_filt):
81
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
82
- if with_logits:
83
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
84
- else:
85
- pred_phrases.append(pred_phrase)
86
-
87
- return boxes_filt, pred_phrases
88
-
89
- def show_mask(mask, ax, random_color=False):
90
- if random_color:
91
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
92
- else:
93
- color = np.array([30/255, 144/255, 255/255, 0.6])
94
- h, w = mask.shape[-2:]
95
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
96
- ax.imshow(mask_image)
97
-
98
-
99
- def show_box(box, ax, label):
100
- x0, y0 = box[0], box[1]
101
- w, h = box[2] - box[0], box[3] - box[1]
102
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
103
- ax.text(x0, y0, label)
104
-
105
-
106
- def save_mask_data(output_dir, mask_list, box_list, label_list):
107
- value = 0 # 0 for background
108
 
109
- mask_img = torch.zeros(mask_list.shape[-2:])
110
- for idx, mask in enumerate(mask_list):
111
- mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
112
- plt.figure(figsize=(10, 10))
113
- plt.imshow(mask_img.numpy())
114
- plt.axis('off')
115
- plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
116
-
117
- json_data = [{
118
- 'value': value,
119
- 'label': 'background'
120
- }]
121
- for label, box in zip(label_list, box_list):
122
- value += 1
123
- name, logit = label.split('(')
124
- logit = logit[:-1] # the last is ')'
125
- json_data.append({
126
- 'value': value,
127
- 'label': name,
128
- 'logit': float(logit),
129
- 'box': box.numpy().tolist(),
130
- })
131
- with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
132
- json.dump(json_data, f)
133
-
134
-
135
- if __name__ == "__main__":
136
-
137
- parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
138
- parser.add_argument("--config", type=str, required=True, help="path to config file")
139
- parser.add_argument(
140
- "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
141
- )
142
- parser.add_argument(
143
- "--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h"
144
- )
145
- parser.add_argument(
146
- "--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file"
147
- )
148
- parser.add_argument(
149
- "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
150
- )
151
- parser.add_argument(
152
- "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
153
- )
154
- parser.add_argument("--input_image", type=str, required=True, help="path to image file")
155
- parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
156
- parser.add_argument(
157
- "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
158
- )
159
 
160
- parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
161
- parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
162
-
163
- parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
164
- args = parser.parse_args()
165
-
166
- # cfg
167
- config_file = args.config # change the path of the model config file
168
- grounded_checkpoint = args.grounded_checkpoint # change the path of the model
169
- sam_version = args.sam_version
170
- sam_checkpoint = args.sam_checkpoint
171
- sam_hq_checkpoint = args.sam_hq_checkpoint
172
- use_sam_hq = args.use_sam_hq
173
- image_path = args.input_image
174
- text_prompt = args.text_prompt
175
- output_dir = args.output_dir
176
- box_threshold = args.box_threshold
177
- text_threshold = args.text_threshold
178
- device = args.device
179
-
180
- # make dir
181
- os.makedirs(output_dir, exist_ok=True)
182
- # load image
183
- image_pil, image = load_image(image_path)
184
- # load model
185
- model = load_model(config_file, grounded_checkpoint, device=device)
186
 
187
- # visualize raw image
188
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
 
 
 
 
189
 
190
- # run grounding dino model
191
- boxes_filt, pred_phrases = get_grounding_output(
192
- model, image, text_prompt, box_threshold, text_threshold, device=device
193
- )
 
 
 
194
 
195
- # initialize SAM
196
- if use_sam_hq:
197
- predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
198
- else:
199
- predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
200
- image = cv2.imread(image_path)
201
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
202
  predictor.set_image(image)
203
 
204
- size = image_pil.size
205
  H, W = size[1], size[0]
206
  for i in range(boxes_filt.size(0)):
207
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
@@ -209,27 +98,30 @@ if __name__ == "__main__":
209
  boxes_filt[i][2:] += boxes_filt[i][:2]
210
 
211
  boxes_filt = boxes_filt.cpu()
212
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
 
213
 
214
  masks, _, _ = predictor.predict_torch(
215
- point_coords = None,
216
- point_labels = None,
217
- boxes = transformed_boxes.to(device),
218
- multimask_output = False,
219
  )
220
 
221
- # draw output image
222
- plt.figure(figsize=(10, 10))
223
- plt.imshow(image)
224
- for mask in masks:
225
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
226
- for box, label in zip(boxes_filt, pred_phrases):
227
- show_box(box.numpy(), plt.gca(), label)
228
 
 
 
229
  plt.axis('off')
230
- plt.savefig(
231
- os.path.join(output_dir, "grounded_sam_output.jpg"),
232
- bbox_inches="tight", dpi=300, pad_inches=0.0
233
- )
234
 
235
- save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
 
 
 
 
 
 
 
1
+ from GroundingDINO.groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
2
+ from io import BytesIO
3
  import os
4
  import copy
5
 
 
17
 
18
  # segment anything
19
  from segment_anything import (
20
+ build_sam,
21
+ build_sam_hq,
22
  SamPredictor
23
  )
24
  import cv2
 
26
  import matplotlib.pyplot as plt
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def load_model(model_config_path, model_checkpoint_path, device):
30
  args = SLConfig.fromfile(model_config_path)
31
  args.device = device
32
  model = build_model(args)
33
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
34
+ load_res = model.load_state_dict(
35
+ clean_state_dict(checkpoint["model"]), strict=False)
36
  print(load_res)
37
  _ = model.eval()
38
  return model
 
59
  boxes_filt = boxes_filt[filt_mask] # num_filt, 4
60
  logits_filt.shape[0]
61
 
62
+ return boxes_filt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ def grounded_sam_demo(input_pil, config_file, grounded_checkpoint, sam_checkpoint,
66
+ text_prompt, box_threshold=0.3, text_threshold=0.25,
67
+ device="cuda"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Convert PIL image to tensor with normalization
70
+ transform = Compose([
71
+ RandomResize([800], max_size=1333),
72
+ ToTensor(),
73
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
74
+ ])
75
 
76
+ if input_pil.mode != "RGB":
77
+ input_pil = input_pil.convert("RGB")
78
+
79
+ image, _ = transform(input_pil, None)
80
+
81
+ # Load model
82
+ model = load_model(config_file, grounded_checkpoint, device=device)
83
 
84
+ # Get grounding dino model output
85
+ boxes_filt = get_grounding_output(
86
+ model, image, text_prompt, box_threshold, text_threshold, device=device)
87
+
88
+ # Initialize SAM
89
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
90
+ image = cv2.cvtColor(np.array(input_pil), cv2.COLOR_RGB2BGR)
91
  predictor.set_image(image)
92
 
93
+ size = input_pil.size
94
  H, W = size[1], size[0]
95
  for i in range(boxes_filt.size(0)):
96
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
 
98
  boxes_filt[i][2:] += boxes_filt[i][:2]
99
 
100
  boxes_filt = boxes_filt.cpu()
101
+ transformed_boxes = predictor.transform.apply_boxes_torch(
102
+ boxes_filt, image.shape[:2]).to(device)
103
 
104
  masks, _, _ = predictor.predict_torch(
105
+ point_coords=None,
106
+ point_labels=None,
107
+ boxes=transformed_boxes.to(device),
108
+ multimask_output=False,
109
  )
110
 
111
+ # Create mask image
112
+ value = 0 # 0 for background
113
+ mask_img = torch.zeros(masks.shape[-2:])
114
+ for idx, mask in enumerate(masks):
115
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
 
 
116
 
117
+ fig = plt.figure(figsize=(10, 10))
118
+ plt.imshow(mask_img.numpy())
119
  plt.axis('off')
 
 
 
 
120
 
121
+ buf = BytesIO()
122
+ plt.savefig(buf, format='png', bbox_inches="tight",
123
+ dpi=300, pad_inches=0.0)
124
+ buf.seek(0)
125
+ out_pil = Image.open(buf)
126
+
127
+ return out_pil
handler.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+ from test import just_get_sd_mask
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ print(os.listdir('/usr/local/'))
10
+ print(torch.version.cuda)
11
+
12
+ class EndpointHandler():
13
+ def __init__(self, path="."):
14
+ is_production = True
15
+
16
+ if False:
17
+ return
18
+
19
+ os.chdir(path)
20
+
21
+ os.environ['AM_I_DOCKER'] = 'False'
22
+ os.environ['BUILD_WITH_CUDA'] = 'True'
23
+ os.environ['CUDA_HOME'] = '/usr/local/cuda-11.7/' if is_production else '/usr/local/cuda-12.1/'
24
+
25
+ # Install Segment Anything
26
+ subprocess.run(["python", "-m", "pip", "install", "-e", "segment_anything"])
27
+
28
+ # Install Grounding DINO
29
+ subprocess.run(["python", "-m", "pip", "install", "-e", "GroundingDINO"])
30
+
31
+ # Install diffusers
32
+ subprocess.run(["pip", "install", "--upgrade", "diffusers[torch]"])
33
+
34
+ # Install osx
35
+ subprocess.run(["git", "submodule", "update", "--init", "--recursive"])
36
+ subprocess.run(["bash", "grounded-sam-osx/install.sh"], cwd="grounded-sam-osx")
37
+
38
+ # Install RAM & Tag2Text
39
+ subprocess.run(["git", "clone", "https://github.com/xinyu1205/recognize-anything.git"])
40
+ subprocess.run(["pip", "install", "-r", "./recognize-anything/requirements.txt"])
41
+ subprocess.run(["pip", "install", "-e", "./recognize-anything/"])
42
+
43
+ def __call__(self, data):
44
+ mask_pil = just_get_sd_mask(Image.open("assets/demo1.jpg"), "bear", 10)
45
+
46
+ if mask_pil.mode != 'RGB':
47
+ mask_pil = mask_pil.convert('RGB')
48
+
49
+ # Convert PIL image to byte array
50
+ img_byte_arr = BytesIO()
51
+ mask_pil.save(img_byte_arr, format='JPEG')
52
+ img_byte_arr = img_byte_arr.getvalue()
53
+
54
+ # Upload to file.io
55
+ response = requests.post("https://file.io/", files={"file": img_byte_arr})
56
+ url = response.json().get('link')
57
+
58
+ return {"url": url}
handler_test.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler(path=".")
5
+
6
+ # prepare sample payload
7
+ non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
8
+
9
+ # test the handler
10
+ non_holiday_pred=my_handler(non_holiday_payload)
11
+
12
+ # show results
13
+ print("non_holiday_pred", non_holiday_pred)
test.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from grounded_sam_demo import grounded_sam_demo
2
+ import numpy as np
3
+ from PIL import Image
4
+ from scipy.ndimage import convolve
5
+ from scipy.ndimage import binary_dilation
6
+
7
+
8
+ def get_sd_mask(color_mask_pil, target=(72, 4, 84), tolerance=50):
9
+ image_array = np.array(color_mask_pil)
10
+
11
+ # Update target based on the number of color channels in the image array
12
+ target = np.array(list(target) + [255] *
13
+ (image_array.shape[-1] - len(target)))
14
+
15
+ mask = np.abs(image_array - target) <= tolerance
16
+ mask = np.all(mask, axis=-1)
17
+
18
+ new_image_array = np.ones_like(image_array) * 255 # Start with white
19
+ # Apply black where condition met
20
+ new_image_array[mask] = [0] * image_array.shape[-1]
21
+
22
+ return Image.fromarray(new_image_array)
23
+
24
+
25
+ def expand_white_pixels(input_pil, expand_by=1):
26
+ img_array = np.array(input_pil)
27
+ is_white = np.all(img_array == 255, axis=-1)
28
+
29
+ kernel = np.ones((2*expand_by+1, 2*expand_by+1), bool)
30
+ expanded_white = binary_dilation(is_white, structure=kernel)
31
+
32
+ expanded_array = np.where(expanded_white[..., None], 255, img_array)
33
+
34
+ expanded_pil = Image.fromarray(expanded_array.astype('uint8'))
35
+ return expanded_pil
36
+
37
+
38
+ config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
39
+ grounded_checkpoint = "groundingdino_swint_ogc.pth"
40
+ sam_checkpoint = "sam_hq_vit_h.pth"
41
+
42
+
43
+ def just_get_sd_mask(input_pil, text_prompt, padding):
44
+ print("Doing sam")
45
+
46
+ colored_mask_pil = grounded_sam_demo(
47
+ input_pil, config_file, grounded_checkpoint, sam_checkpoint, text_prompt)
48
+
49
+ print("doing to white")
50
+
51
+ sd_mask_pil = get_sd_mask(colored_mask_pil)
52
+
53
+ print("expanding white pixels")
54
+
55
+ sd_mask_withpadding_pil = expand_white_pixels(sd_mask_pil, padding)
56
+
57
+ return sd_mask_withpadding_pil