curt-park committed
Commit f17c02a
1 Parent(s): 4ddb621

Add clip query

Files changed (2):
  1. ViT-B-32.pt +3 -0
  2. app.py +63 -14
ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
+size 353976522
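ViT-B-32.pt is tracked with Git LFS, so only the pointer above (its `oid` and `size`) lives in the repository; the ~354 MB CLIP checkpoint itself is fetched by LFS. A quick sanity check that the real weights were pulled rather than the pointer text — a sketch using the values from this diff, not part of the commit:

# Sketch: verify the LFS pointer resolved to the real ViT-B/32 checkpoint.
# A pointer file is ~130 bytes of text; the checkpoint is ~354 MB.
import hashlib
import os

assert os.path.getsize("ViT-B-32.pt") == 353976522, "run `git lfs pull` first"
digest = hashlib.sha256(open("ViT-B-32.pt", "rb").read()).hexdigest()
assert digest == "40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af"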
app.py CHANGED
@@ -1,8 +1,9 @@
 import os
 from functools import lru_cache
 from random import randint
-from typing import Dict, List
+from typing import Any, Callable, Dict, List, Tuple
 
+import clip
 import cv2
 import gradio as gr
 import numpy as np
@@ -13,17 +14,27 @@ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
 CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
 MODEL_TYPE = "default"
 MAX_WIDTH = MAX_HEIGHT = 800
+CLIP_WIDTH = CLIP_HEIGHT = 300
 THRESHOLD = 0.05
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 @lru_cache
-def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
+def load_mask_generator() -> SamAutomaticMaskGenerator:
     sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
     mask_generator = SamAutomaticMaskGenerator(sam)
     return mask_generator
 
 
+@lru_cache
+def load_clip(
+    name: str = "ViT-B-32.pt",
+) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
+    model_path = os.path.join(".", name)
+    model, preprocess = clip.load(model_path, device=device)
+    return model.to(device), preprocess
+
+
 def adjust_image_size(image: np.ndarray) -> np.ndarray:
     height, width = image.shape[:2]
     if height > width:
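Both loaders are wrapped in functools.lru_cache, so each model is read from disk at most once per process. A minimal sketch of the effect, not part of the commit and assuming app.py is importable as a module:

# Sketch: lru_cache memoizes the zero-argument call, so the CLIP weights
# load once and later calls return the same cached (model, preprocess) pair.
from app import load_clip  # hypothetical import of this Space's app.py

m1, p1 = load_clip()
m2, p2 = load_clip()
assert m1 is m2 and p1 is p2  # no second load from disk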
@@ -36,23 +47,56 @@ def adjust_image_size(image: np.ndarray) -> np.ndarray:
     return image
 
 
+@torch.no_grad()
+def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
+    model, preprocess = load_clip()
+    preprocessed = [preprocess(crop) for crop in crops]
+    preprocessed = torch.stack(preprocessed).to(device)
+    token = clip.tokenize(query).to(device)
+    img_features = model.encode_image(preprocessed)
+    txt_features = model.encode_text(token)
+    img_features /= img_features.norm(dim=-1, keepdim=True)
+    txt_features /= txt_features.norm(dim=-1, keepdim=True)
+    probs = 100.0 * img_features @ txt_features.T
+    return probs[:, 0].softmax(dim=0)
+
+
 def filter_masks(
-    masks: List[Dict[str, np.ndarray]],
+    image: np.ndarray,
+    masks: List[Dict[str, Any]],
     predicted_iou_threshold: float,
     stability_score_threshold: float,
     query: str,
     clip_threshold: float,
-) -> List[np.ndarray]:
-    filtered_masks: List[Dict[str, np.ndarray]] = []
+) -> List[Dict[str, Any]]:
+    cropped_masks: List[PIL.Image.Image] = []
+    filtered_masks: List[Dict[str, Any]] = []
+
     for mask in masks:
         if (
             mask["predicted_iou"] < predicted_iou_threshold
             or mask["stability_score"] < stability_score_threshold
         ):
             continue
+
         filtered_masks.append(mask)
 
-    return [mask["segmentation"] for mask in filtered_masks]
+        x, y, w, h = mask["bbox"]
+        crop = image[y : y + h, x : x + w]
+        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+        crop = PIL.Image.fromarray(np.uint8(crop * 255)).convert("RGB")
+        crop.resize((CLIP_WIDTH, CLIP_HEIGHT))
+        cropped_masks.append(crop)
+
+    if query and filtered_masks:
+        scores = get_scores(cropped_masks, query)
+        filtered_masks = [
+            filtered_masks[i]
+            for i, score in enumerate(scores)
+            if score > clip_threshold
+        ]
+
+    return filtered_masks
 
 
 def draw_masks(
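Note on get_scores: the 100-scaled image-text cosine similarities are softmaxed across the crops, so the returned scores sum to 1 and clip_threshold acts as a relative cutoff among the candidate masks, not an absolute similarity. A small numeric sketch with hypothetical values:

import torch

# Three hypothetical crop-vs-query cosine similarities in [-1, 1].
sims = torch.tensor([0.28, 0.21, 0.19])
scores = (100.0 * sims).softmax(dim=0)  # ~[0.9990, 0.0009, 0.0001]
keep = scores > 0.05                    # only the best-matching crop survives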
@@ -62,7 +106,7 @@ def draw_masks(
         color = [randint(127, 255) for _ in range(3)]
 
         # draw mask overlay
-        colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
+        colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
         colored_mask = np.moveaxis(colored_mask, 0, -1)
         masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
         image_overlay = masked.filled()
@@ -70,7 +114,7 @@ def draw_masks(
 
         # draw contour
         contours, _ = cv2.findContours(
-            np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
         )
         cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
     return image
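The two draw_masks changes follow from filter_masks now returning the full SAM record dicts rather than bare segmentation arrays, so the boolean mask is looked up under the "segmentation" key. A sketch of the shape handling, using a hypothetical 8x8 record:

import numpy as np

mask = {"segmentation": np.zeros((8, 8), dtype=bool)}  # hypothetical SAM record
colored = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)  # (3, 8, 8)
colored = np.moveaxis(colored, 0, -1)  # one channel per color component
assert colored.shape == (8, 8, 3)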
@@ -88,7 +132,12 @@ def segment(
     image = adjust_image_size(cv2.imread(image_path))
     masks = mask_generator.generate(image)
     masks = filter_masks(
-        masks, predicted_iou_threshold, stability_score_threshold, query, clip_threshold
+        image,
+        masks,
+        predicted_iou_threshold,
+        stability_score_threshold,
+        query,
+        clip_threshold,
     )
     image = draw_masks(image, masks)
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -112,16 +161,16 @@ demo = gr.Interface(
         [
             0.9,
             0.8,
-            0.05,
+            0.15,
             os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
-            "",
+            "A dog only",
         ],
         [
             0.9,
             0.8,
-            0.05,
+            0.1,
             os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
-            "",
+            "A bridge on the water",
         ],
         [
             0.9,
@@ -135,7 +184,7 @@ demo = gr.Interface(
             0.8,
             0.05,
             os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
-            "",
+            "horse",
         ],
     ],
 )
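For reference, the first example row corresponds to a direct call like the sketch below. The argument order (IoU threshold, stability threshold, CLIP threshold, image path, query) is inferred from the examples list and gradio input order, not stated in this hunk:

# Hypothetical direct invocation mirroring the dog example row.
from app import segment  # assumes app.py is importable as a module

annotated = segment(0.9, 0.8, 0.15, "examples/dog.jpg", "A dog only")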
 