yucornetto committed on
Commit a5b4e6b
1 Parent(s): 6bd742d

Upload 106 files

Files changed (3)
  1. app.py +2 -3
  2. demo_all_text_embedding_cache.pth +3 -0
  3. fcclip/fcclip.py +60 -21
app.py CHANGED
@@ -27,8 +27,6 @@ from detectron2.data import MetadataCatalog
 from detectron2.projects.deeplab import add_deeplab_config
 
 
-coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
-
 # import FCCLIP project
 from fcclip import add_maskformer2_config, add_fcclip_config
 from demo.predictor import DefaultPredictor, OpenVocabVisualizer
@@ -46,6 +44,7 @@ add_maskformer2_config(cfg)
 add_fcclip_config(cfg)
 cfg.merge_from_file("configs/coco/panoptic-segmentation/fcclip/fcclip_convnext_large_eval_ade20k.yaml")
 os.system("gdown 1-91PIns86vyNaL3CzMmDD39zKGnPMtvj")
+os.system("gdown 1-91PIns86vyNaL3CzMmDD39zKGnPMtvj")
 cfg.MODEL.WEIGHTS = './fcclip_cocopan.pth'
 cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = False
 cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
@@ -160,7 +159,7 @@ def inference(image_path, vocab, label_list):
 
     im = cv2.imread(image_path)
    outputs = predictor(im)
-    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, scale=1.0, instance_mode=ColorMode.IMAGE)
     panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
     return Image.fromarray(np.uint8(panoptic_result)).convert('RGB')
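For context, a minimal sketch of the inference helper as it reads after this patch; only the lines touched by the hunks are shown, and predictor, demo_metadata, and OpenVocabVisualizer are assumed to be defined earlier in app.py and demo/predictor.py. Compared with the previous revision, the visualizer now draws full-color overlays on the original image (ColorMode.IMAGE at scale 1.0) rather than the grayscale-background ColorMode.IMAGE_BW at scale 1.2.

# Sketch only: assumes predictor, demo_metadata and OpenVocabVisualizer exist as in app.py.
import cv2
import numpy as np
from PIL import Image
from detectron2.utils.visualizer import ColorMode

def inference(image_path, vocab, label_list):
    # ... vocabulary handling omitted (not part of this hunk) ...
    im = cv2.imread(image_path)      # BGR image from OpenCV
    outputs = predictor(im)          # FC-CLIP panoptic prediction
    # Full-color overlay drawn on the RGB view of the input image.
    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, scale=1.0, instance_mode=ColorMode.IMAGE)
    panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"),
                                          outputs["panoptic_seg"][1]).get_image()
    return Image.fromarray(np.uint8(panoptic_result)).convert('RGB')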
 
demo_all_text_embedding_cache.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ee4c83884a03f41e1078a5b0916f6a26606258c0031e4e22e74c93c6672e9c9
+size 7848107
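The new cache file is tracked through Git LFS, so the diff above contains only the LFS pointer (oid and size), not the roughly 7.8 MB payload. Per the comments added in fcclip/fcclip.py below, the file is a torch-saved dict mapping each class-name string to its C-dimensional text embedding. A minimal inspection sketch, assuming only that format (the file name comes from this commit, everything else is illustrative):

# Sketch: peek inside the LFS-tracked cache once it has been pulled locally.
import torch

cache = torch.load("demo_all_text_embedding_cache.pth", map_location="cpu")
print(len(cache), "cached class names")      # COCO + ADE20K + LVIS vocabulary
name, emb = next(iter(cache.items()))
print(name, tuple(emb.shape))                # one C-dimensional embedding per class name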
fcclip/fcclip.py CHANGED
@@ -18,6 +18,7 @@ from .modeling.matcher import HungarianMatcher
 
 
 from .modeling.transformer_decoder.fcclip_transformer_decoder import MaskPooling, get_classification_logits
+import os
 VILD_PROMPT = [
     "a photo of a {}.",
     "This is a photo of a {}",
@@ -35,6 +36,20 @@ VILD_PROMPT = [
     "There is a large {} in the scene.",
 ]
 
+def split_labels(x):
+    res = []
+    for x_ in x:
+        x_ = x_.replace(', ', ',')
+        x_ = x_.split(',') # there can be multiple synonyms for single class
+        res.append(x_)
+    return res
+
+def fill_all_templates_ensemble(x_=''):
+    res = []
+    for x in x_:
+        for template in VILD_PROMPT:
+            res.append(template.format(x))
+    return res, len(res) // len(VILD_PROMPT)
 
 @META_ARCH_REGISTRY.register()
 class FCCLIP(nn.Module):
@@ -129,14 +144,15 @@ class FCCLIP(nn.Module):
         _, self.train_num_templates, self.train_class_names = self.prepare_class_names_from_metadata(train_metadata, train_metadata)
         self.category_overlapping_mask, self.test_num_templates, self.test_class_names = self.prepare_class_names_from_metadata(test_metadata, train_metadata)
 
+        self.demo_all_text_embedding_cache = {}
+        # This consists of COCO, ADE20K, LVIS
+        if os.path.exists("demo_all_text_embedding_cache.pth"):
+            # key: str of class name, value: tensor in shape of C
+            self.demo_all_text_embedding_cache = torch.load("demo_all_text_embedding_cache.pth", map_location=self.device)
+            self.demo_all_text_embedding_cache = {k:v.to(self.device) for k,v in self.demo_all_text_embedding_cache.items()}
+
+
     def prepare_class_names_from_metadata(self, metadata, train_metadata):
-        def split_labels(x):
-            res = []
-            for x_ in x:
-                x_ = x_.replace(', ', ',')
-                x_ = x_.split(',') # there can be multiple synonyms for single class
-                res.append(x_)
-            return res
         # get text classifier
         try:
             class_names = split_labels(metadata.stuff_classes) # it includes both thing and stuff
@@ -152,13 +168,6 @@ class FCCLIP(nn.Module):
             category_overlapping_list.append(is_overlapping)
         category_overlapping_mask = torch.tensor(
             category_overlapping_list, dtype=torch.long)
-
-        def fill_all_templates_ensemble(x_=''):
-            res = []
-            for x in x_:
-                for template in VILD_PROMPT:
-                    res.append(template.format(x))
-            return res, len(res) // len(VILD_PROMPT)
 
         num_templates = []
         templated_class_names = []
@@ -195,17 +204,47 @@ class FCCLIP(nn.Module):
             return self.train_text_classifier, self.train_num_templates
         else:
             if self.test_text_classifier is None:
+                try:
+                    nontemplated_class_names = split_labels(self.test_metadata.stuff_classes) # it includes both thing and stuff
+                except:
+                    # this could be for insseg, where only thing_classes are available
+                    nontemplated_class_names = split_labels(self.test_metadata.thing_classes)
+
+                text2classifier = {}
+                test_class_names = []
+                uncached_class_name = []
                 text_classifier = []
+                # exclude those already in cache
+                for class_names in nontemplated_class_names:
+                    for class_name in class_names:
+                        if class_name in self.demo_all_text_embedding_cache:
+                            text2classifier[class_name] = self.demo_all_text_embedding_cache[class_name].to(self.device)
+                        else:
+                            test_class_names += fill_all_templates_ensemble([class_name])[0]
+                            uncached_class_name.append(class_name)
+                print("Uncached texts:", len(uncached_class_name), uncached_class_name, test_class_names)
                 # this is needed to avoid oom, which may happen when num of class is large
                 bs = 128
-                for idx in range(0, len(self.test_class_names), bs):
-                    text_classifier.append(self.backbone.get_text_classifier(self.test_class_names[idx:idx+bs], self.device).detach())
-                text_classifier = torch.cat(text_classifier, dim=0)
+                for idx in range(0, len(test_class_names), bs):
+                    text_classifier.append(self.backbone.get_text_classifier(test_class_names[idx:idx+bs], self.device).detach())
+
+                if len(text_classifier) > 0:
+                    text_classifier = torch.cat(text_classifier, dim=0)
+                    # average across templates and normalization.
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    assert text_classifier.shape[0] == len(uncached_class_name)
+                    for idx in range(len(uncached_class_name)):
+                        self.demo_all_text_embedding_cache[uncached_class_name[idx]] = text_classifier[idx]
+                        text2classifier[uncached_class_name[idx]] = text_classifier[idx]
+                    #torch.save({k:v for k, v in self.demo_all_text_embedding_cache.items()}, "demo_all_text_embedding_cache.pth")
 
-                # average across templates and normalization.
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
-                text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                text_classifier = []
+                for class_names in nontemplated_class_names:
+                    for text in class_names:
+                        text_classifier.append(text2classifier[text].to(self.device))
+                text_classifier = torch.stack(text_classifier, dim=0).to(self.device)
                 self.test_text_classifier = text_classifier
             return self.test_text_classifier, self.test_num_templates
 
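Taken together, the fcclip.py changes hoist split_labels and fill_all_templates_ensemble to module level and memoize per-class text embeddings, so only class names absent from demo_all_text_embedding_cache.pth are re-encoded by the CLIP text tower when the demo vocabulary changes. A minimal sketch of that caching scheme in isolation; encode_texts is a hypothetical stand-in for self.backbone.get_text_classifier, and the other names are illustrative:

# Sketch of the per-class caching idea this patch implements (names are illustrative).
import torch

def classifier_for(class_name, cache, encode_texts, templates):
    """Return a cached class embedding, computing and memoizing it on a miss."""
    if class_name in cache:                        # cache hit: reuse the stored embedding
        return cache[class_name]
    prompts = [t.format(class_name) for t in templates]
    emb = encode_texts(prompts)                    # (len(templates), C) text embeddings
    emb = emb / emb.norm(dim=-1, keepdim=True)     # normalize each templated embedding
    emb = emb.mean(dim=0)                          # average across the prompt templates
    emb = emb / emb.norm()                         # renormalize the ensembled embedding
    cache[class_name] = emb                        # memoize for subsequent vocabularies
    return emb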