pablovela5620 committed
Commit f0d4b35
1 Parent(s): 48f587e

initial working demo

Files changed (4)
  1. .gitignore +2 -1
  2. gradio_demo.py +97 -1
  3. main.py +178 -0
  4. models.py +234 -0
.gitignore CHANGED
@@ -159,4 +159,5 @@ cython_debug/
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  #.idea/
  .vscode/*
- static/*
+ static/*
+ model/*
gradio_demo.py CHANGED
@@ -4,10 +4,27 @@ from fastapi.staticfiles import StaticFiles
  import uvicorn
  import gradio as gr
  from datetime import datetime
+ from typing import Union, List
+
+ import cv2
+ import torch
+ from main import grounding_dino_detect
+ from models import (
+     run_segmentation,
+     resize_img,
+     load_image,
+     get_downloaded_model_path,
+     load_grounding_model,
+     create_sam,
+     CONFIG_PATH,
+ )
+ from segment_anything import SamPredictor
+ from segment_anything.modeling import Sam
+ from groundingdino.models import GroundingDINO

  import rerun as rr

- rr.init("cube")
+ rr.init("GroundingSAM")

  # create a FastAPI app
  app = FastAPI()
@@ -20,12 +37,91 @@ static_dir.mkdir(parents=True, exist_ok=True)
  app.mount("/static", StaticFiles(directory=static_dir), name="static")


+ def log_video_segmentation(
+     video_path: Path,
+     prompt: str,
+     model: GroundingDINO,
+     predictor: Sam,
+     device: str = "cpu",
+ ):
+     assert video_path.exists()
+     cap = cv2.VideoCapture(str(video_path))
+
+     idx = 0
+     while cap.isOpened():
+         ret, bgr = cap.read()
+         if not ret or idx > 20:
+             break
+         rr.set_time_sequence("frame", idx)
+         rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+         rgb = resize_img(rgb, 512)
+         rr.log_image("image", rgb)
+
+         detections, phrases, id_from_phrase = grounding_dino_detect(
+             model, device, rgb, prompt
+         )
+
+         predictor.set_image(rgb)
+         run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
+
+         idx += 1
+
+
+ def log_images_segmentation(
+     images: list[Union[str, Path]],
+     prompt: str,
+     model: GroundingDINO,
+     predictor: Sam,
+     device: str = "cpu",
+ ):
+     for n, image_uri in enumerate(images):
+         rr.set_time_sequence("image", n)
+         image = load_image(image_uri)
+         rr.log_image("image", image)
+
+         detections, phrases, id_from_phrase = grounding_dino_detect(
+             model, device, image, prompt
+         )
+
+         predictor.set_image(image)
+         run_segmentation(predictor, image, detections, phrases, id_from_phrase)
+
+
  # Gradio stuff
  def predict():
      file_name = f"{datetime.utcnow().strftime('%s')}.html"
      file_path = static_dir / file_name
      rec = rr.memory_recording()

+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # load model
+     grounded_checkpoint = get_downloaded_model_path("grounding")
+     model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
+     sam = create_sam("vit_b", device)
+
+     predictor = SamPredictor(sam)
+
+     # prompt = "tires"
+
+     log_video_segmentation(
+         Path("dog_and_woman.mp4"),
+         "dog, woman",
+         model,
+         predictor,
+         device=device,
+     )
+
+     # log_images_segmentation(
+     #     [
+     #         "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
+     #     ],
+     #     prompt,
+     #     model,
+     #     predictor,
+     #     device=device,
+     # )
+
      with open(file_path, "w") as f:
          f.write(rec.as_html())
      iframe = f"""<iframe src="/static/{file_name}" width="950" height="712"></iframe>"""
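Note that the hunk above ends partway through predict(); the Gradio UI wiring at the bottom of gradio_demo.py is unchanged by this commit and therefore not displayed. Purely for orientation, here is a minimal sketch of how an iframe-returning predict() is typically hooked up to the same FastAPI app that serves /static. The names app, predict, gr, and uvicorn come from the file above; everything else is an assumption, and the actual wiring in this Space may differ.

# --- sketch only, not part of this commit ---
with gr.Blocks() as demo:
    button = gr.Button("Run GroundingSAM")
    html = gr.HTML()
    # predict() writes the Rerun recording to static/<timestamp>.html and
    # returns an <iframe> string pointing at /static/<timestamp>.html
    button.click(fn=predict, inputs=[], outputs=[html])

# serve the Gradio UI from the FastAPI app that already mounts /static
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)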
main.py ADDED
@@ -0,0 +1,178 @@
+ #!/usr/bin/env python3
+ """
+ Example of using Rerun to log and visualize the output of Grounded DINO + Segment Anything.
+
+ See: [segment_anything](https://github.com/IDEA-Research/Grounded-Segment-Anything).
+
+ Can be used to test mask-generation on one or more images, as well as videos. Images can be local file-paths
+ or remote urls. Videos must be local file-paths. Can use multiple prompts.
+ """
+
+
+ import argparse
+ import logging
+ import rerun as rr
+ import torch
+ import cv2
+ from pathlib import Path
+ from models import CONFIG_PATH, MODEL_URLS, get_downloaded_model_path
+ from models import load_grounding_model, create_sam, load_image, image_to_tensor
+ from models import get_grounding_output, run_segmentation, resize_img
+ from segment_anything import SamPredictor
+ from segment_anything.modeling import Sam
+ from groundingdino.models import GroundingDINO
+
+
+ def log_images_segmentation(args, model: GroundingDINO, predictor: Sam):
+     for n, image_uri in enumerate(args.images):
+         rr.set_time_sequence("image", n)
+         image = load_image(image_uri)
+         rr.log_image("image", image)
+
+         detections, phrases, id_from_phrase = grounding_dino_detect(
+             model, args.device, image, args.prompt
+         )
+
+         predictor.set_image(image)
+         run_segmentation(predictor, image, detections, phrases, id_from_phrase)
+
+
+ def grounding_dino_detect(model, device, image, prompt):
+     image_tensor = image_to_tensor(image)
+     logging.info(f"Running GroundedDINO with DETECTION PROMPT {prompt}.")
+     boxes_filt, box_phrases = get_grounding_output(
+         model, image_tensor, prompt, 0.3, 0.25, device=device
+     )
+     logging.info(f"Grounded output with prediction phrases: {box_phrases}")
+
+     # denormalize boxes (from [0, 1] to image size)
+     H, W, _ = image.shape
+     for i in range(boxes_filt.size(0)):
+         boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+         boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+         boxes_filt[i][2:] += boxes_filt[i][:2]
+
+     id_from_phrase = {phrase: i for i, phrase in enumerate(set(box_phrases), start=1)}
+     box_ids = [id_from_phrase[phrase] for phrase in box_phrases]  # One mask per box
+
+     # Make sure we have an AnnotationInfo present for every class-id used in this image
+     rr.log_annotation_context(
+         "image",
+         [
+             rr.AnnotationInfo(id=id, label=phrase)
+             for phrase, id in id_from_phrase.items()
+         ],
+         timeless=False,
+     )
+
+     rr.log_rects(
+         "image/detections",
+         rects=boxes_filt.numpy(),
+         class_ids=box_ids,
+         rect_format=rr.RectFormat.XYXY,
+     )
+
+     return boxes_filt, box_phrases, id_from_phrase
+
+
+ def log_video_segmentation(args, model: GroundingDINO, predictor: Sam):
+     video_path = args.video_path
+     assert video_path.exists()
+     cap = cv2.VideoCapture(str(video_path))
+
+     idx = 0
+     while cap.isOpened():
+         ret, bgr = cap.read()
+         if not ret:
+             break
+         rr.set_time_sequence("frame", idx)
+         rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+         rgb = resize_img(rgb, 512)
+         rr.log_image("image", rgb)
+
+         detections, phrases, id_from_phrase = grounding_dino_detect(
+             model, args.device, rgb, args.prompt
+         )
+
+         predictor.set_image(rgb)
+         run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
+
+         idx += 1
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="Run IDEA Research Grounded Dino + SAM example.",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     parser.add_argument(
+         "--model",
+         action="store",
+         default="vit_b",
+         choices=MODEL_URLS.keys(),
+         help="Which model to use."
+         "(See: https://github.com/facebookresearch/segment-anything#model-checkpoints)",
+     )
+     parser.add_argument(
+         "--device",
+         action="store",
+         default="cpu",
+         help="Which torch device to use, e.g. cpu or cuda. "
+         "(See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)",
+     )
+
+     parser.add_argument(
+         "--prompt",
+         default="tires and windows",
+         type=str,
+         help="List of prompts to use for bounding box detection.",
+     )
+
+     parser.add_argument(
+         "images", metavar="N", type=str, nargs="*", help="A list of images to process."
+     )
+
+     parser.add_argument(
+         "--bbox-threshold",
+         default=0.3,
+         type=float,
+         help="Threshold for a bounding box to be considered.",
+     )
+
+     parser.add_argument(
+         "--video-path",
+         default=None,
+         type=Path,
+         help="Path to video to run segmentation on",
+     )
+
+     rr.script_add_args(parser)
+     args = parser.parse_args()
+
+     rr.script_setup(args, "grounded_sam")
+     logging.getLogger().addHandler(rr.LoggingHandler("logs"))
+     logging.getLogger().setLevel(logging.INFO)
+
+     # load model
+     grounded_checkpoint = get_downloaded_model_path("grounding")
+     model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=args.device)
+     sam = create_sam(args.model, args.device)
+
+     predictor = SamPredictor(sam)
+
+     if len(args.images) == 0 and args.video_path is None:
+         logging.info("No image provided. Using default.")
+         args.images = [
+             "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
+         ]
+
+     if len(args.images) > 0:
+         log_images_segmentation(args, model, predictor)
+     elif args.video_path is not None:
+         log_video_segmentation(args, model, predictor)
+
+     rr.script_teardown(args)
+
+
+ if __name__ == "__main__":
+     main()
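For reference, main.py as added above can be run from the command line with the flags it defines. For example (the truck.jpg URL is the script's built-in default image; dog_and_woman.mp4 is assumed to be a local file, as used in gradio_demo.py):

    python main.py --prompt "tires and windows" https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    python main.py --device cuda --prompt "dog, woman" --video-path dog_and_woman.mp4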
models.py ADDED
@@ -0,0 +1,234 @@
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Final, List, Mapping
+ from urllib.parse import urlparse
+
+ import cv2
+ from PIL import Image
+ import numpy as np
+ import requests
+ import rerun as rr
+ import torch
+ import torchvision
+ from cv2 import Mat
+ from segment_anything import SamPredictor, sam_model_registry
+ from segment_anything.modeling import Sam
+ from tqdm import tqdm
+
+ # Grounding DINO
+ import GroundingDINO.groundingdino.datasets.transforms as T
+ from GroundingDINO.groundingdino.models import build_model
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
+ from GroundingDINO.groundingdino.util.utils import (
+     clean_state_dict,
+     get_phrases_from_posmap,
+ )
+ from groundingdino.models import GroundingDINO
+
+
+ CONFIG_PATH: Final = (
+     Path(os.path.dirname(__file__))
+     / "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ )
+ MODEL_DIR: Final = Path(os.path.dirname(__file__)) / "model"
+ MODEL_URLS: Final = {
+     "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+     "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+     "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+     "grounding": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+ }
+
+
+ def download_with_progress(url: str, dest: Path) -> None:
+     """Download file with tqdm progress bar."""
+     chunk_size = 1024 * 1024
+     resp = requests.get(url, stream=True)
+     total_size = int(resp.headers.get("content-length", 0))
+     with open(dest, "wb") as dest_file:
+         with tqdm(
+             desc="Downloading model",
+             total=total_size,
+             unit="iB",
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as progress:
+             for data in resp.iter_content(chunk_size):
+                 dest_file.write(data)
+                 progress.update(len(data))
+
+
+ def get_downloaded_model_path(model_name: str) -> Path:
+     """Fetch the segment-anything model to a local cache directory."""
+     model_url = MODEL_URLS[model_name]
+
+     model_location = MODEL_DIR / model_url.split("/")[-1]
+     if not model_location.exists():
+         os.makedirs(MODEL_DIR, exist_ok=True)
+         download_with_progress(model_url, model_location)
+
+     return model_location
+
+
+ def create_sam(model: str, device: str) -> Sam:
+     """Load the segment-anything model, fetching the model-file as necessary."""
+     model_path = get_downloaded_model_path(model)
+
+     logging.info("PyTorch version: {}".format(torch.__version__))
+     logging.info("Torchvision version: {}".format(torchvision.__version__))
+     logging.info("CUDA is available: {}".format(torch.cuda.is_available()))
+
+     logging.info("Building sam from: {}".format(model_path))
+     sam = sam_model_registry[model](checkpoint=model_path)
+     return sam.to(device=device)
+
+
+ def run_segmentation(
+     predictor: SamPredictor,
+     image: Mat,
+     detections,
+     phrases: List[str],
+     id_from_phrase: Mapping[str, int],
+ ) -> None:
+     """Run segmentation on a single image."""
+     if detections.shape[0] == 0:
+         return
+     logging.info("Finding masks")
+     transformed_boxes = predictor.transform.apply_boxes_torch(
+         detections, image.shape[:2]
+     )
+
+     masks, _, _ = predictor.predict_torch(
+         point_coords=None,
+         point_labels=None,
+         boxes=transformed_boxes.to(predictor.device),
+         multimask_output=False,
+     )
+
+     logging.info("Found {} masks".format(len(masks)))
+
+     # Layer all of the masks that belong to a single phrase together
+     segmentation_img = np.zeros((image.shape[0], image.shape[1]))
+     for phrase, mask in zip(phrases, masks):
+         segmentation_img[mask.squeeze()] = id_from_phrase[phrase]
+
+     rr.log_segmentation_image("image/segmentation", segmentation_img)
+
+
+ def is_url(path: str) -> bool:
+     """Check if a path is a url or a local file."""
+     try:
+         result = urlparse(path)
+         return all([result.scheme, result.netloc])
+     except ValueError:
+         return False
+
+
+ def resize_img(img: Mat, max_dimension: int = 512) -> Mat:
+     height, width = img.shape[:2]
+     # Check if either dimension is larger than the maximum
+     if max(height, width) > max_dimension:
+         # Calculate the new dimensions while maintaining the aspect ratio
+         if height > width:
+             new_height = max_dimension
+             new_width = int((new_height * width) / height)
+         else:
+             new_width = max_dimension
+             new_height = int((new_width * height) / width)
+
+         # Resize the image
+         resized_image = cv2.resize(
+             img, (new_width, new_height), interpolation=cv2.INTER_AREA
+         )
+         return resized_image
+     # Image is already within the size limit; return it unchanged
+     return img
+
+
+ def image_to_tensor(image: Mat) -> torch.Tensor:
+     """
+     Assumes an RGB OpenCV image; this is required for the DINO model.
+     """
+     image_pil = Image.fromarray(image)
+     transform = T.Compose(
+         [
+             T.RandomResize([800], max_size=1333),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     image_tensor, _ = transform(image_pil, None)  # 3, h, w
+     return image_tensor
+
+
+ def load_image(image_uri: str) -> Mat:
+     """Conditionally download an image from URL or load it from disk."""
+     logging.info("Loading: {}".format(image_uri))
+     if is_url(image_uri):
+         response = requests.get(image_uri)
+         response.raise_for_status()
+         image_data = np.asarray(bytearray(response.content), dtype="uint8")
+         image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
+     else:
+         image = cv2.imread(image_uri, cv2.IMREAD_COLOR)
+
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     return image
+
+
+ def load_grounding_model(
+     model_config_path: Path, model_checkpoint_path: Path, device: str
+ ) -> GroundingDINO:
+     args = SLConfig.fromfile(model_config_path)
+     args.device = device
+     model = build_model(args)
+     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+     _ = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+     _ = model.eval()
+     return model
+
+
+ def get_grounding_output(
+     model: GroundingDINO,
+     image: torch.Tensor,
+     caption: str,
+     box_threshold: float,
+     text_threshold: float,
+     with_logits: bool = False,
+     device: str = "cpu",
+ ):
+     caption = caption.lower()
+     caption = caption.strip()
+     if not caption.endswith("."):
+         caption = caption + "."
+     model = model.to(device)
+     image = image.to(device)
+     with torch.no_grad():
+         outputs = model(image[None], captions=[caption])
+     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+     boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+
+     # filter output
+     logits_filt = logits.clone()
+     boxes_filt = boxes.clone()
+     filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+     logits_filt = logits_filt[filt_mask]  # num_filt, 256
+     boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+
+     # get phrase
+     tokenlizer = model.tokenizer
+     tokenized = tokenlizer(caption)
+     # build pred
+     pred_phrases = []
+     for logit, box in zip(logits_filt, boxes_filt):
+         pred_phrase = get_phrases_from_posmap(
+             logit > text_threshold, tokenized, tokenlizer
+         )
+
+         if with_logits:
+             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+         else:
+             pred_phrases.append(pred_phrase)
+
+     return boxes_filt, pred_phrases