hsshin98 committed
Commit aff8d56
1 Parent(s): f98e690
app.py CHANGED
@@ -41,7 +41,6 @@ def setup_cfg(args):
     add_cat_seg_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
-    cfg.MODEL.DEVICE = "cpu"
     cfg.freeze()
     return cfg
 
@@ -67,7 +66,10 @@ def get_parser():
             "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
             "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
             "TEST.SLIDING_WINDOW", "True",
-            "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]"],
+            "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
+            "MODEL.DEVICE", "cpu",
+            "MODEL.PROMPT_ENSEMBLE_TYPE", "single"
+        ],
         nargs=argparse.REMAINDER,
     )
     return parser
@@ -75,7 +77,7 @@ def get_parser():
 def save_masks(preds, text):
     preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
     for i, t in enumerate(text):
-        dir = f"masks/mask_{t}.png"
+        dir = f"mask_{t}.png"
         mask = preds == i
         cv2.imwrite(dir, mask * 255)
 
@@ -84,7 +86,7 @@ def predict(image, text):
     cfg = setup_cfg(args)
     demo = VisualizationDemo(cfg, text=text)
     predictions, visualized_output = demo.run_on_image(image)
-    # save_masks(predictions, text.split(','))
+    #save_masks(predictions, text.split(','))
     canvas = fc(visualized_output.fig)
     canvas.draw()
     out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
@@ -97,7 +99,12 @@ if __name__ == "__main__":
 
     iface = gr.Interface(
         fn=predict,
-        inputs=[gr.Image(), gr.Textbox(placeholder="Classes to segment")],
+        inputs=[gr.Image(), gr.Textbox(placeholder='cat, person, background')],
         outputs="image",
-    )
+        description="""## CAT-Seg Demo
+Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
+
+Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
+
+To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
     iface.launch()
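For context on the `--opts` change above: the hard-coded `cfg.MODEL.DEVICE = "cpu"` override moves into the parser's default opts list, together with the new `MODEL.PROMPT_ENSEMBLE_TYPE` setting, so everything now flows through `cfg.merge_from_list(args.opts)` before `cfg.freeze()`. A minimal sketch of that mechanism, using a bare yacs `CfgNode` as a stand-in for the CAT-Seg config (in the real app the extra keys are registered by `add_cat_seg_config`, and the default values shown here are assumptions):

from yacs.config import CfgNode as CN

# Stand-in config; the real one comes from get_cfg() + add_cat_seg_config(cfg).
cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.DEVICE = "cuda"                    # assumed default before the override
cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "imagenet"  # hypothetical default value

# merge_from_list consumes the same flat [KEY, VALUE, KEY, VALUE, ...] list
# that get_parser() now builds as the default for --opts.
cfg.merge_from_list([
    "MODEL.DEVICE", "cpu",
    "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
])
cfg.freeze()  # as in setup_cfg: any later write raises an error

assert cfg.MODEL.DEVICE == "cpu"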
cat_seg/modeling/transformer/cat_seg_predictor.py CHANGED
@@ -50,13 +50,13 @@ class CATSegPredictor(nn.Module):
 
         import json
         # use class_texts in train_forward, and test_class_texts in test_forward
-        with open(train_class_json, 'r') as f_in:
-            self.class_texts = json.load(f_in)
-        with open(test_class_json, 'r') as f_in:
-            self.test_class_texts = json.load(f_in)
-        assert self.class_texts != None
-        if self.test_class_texts == None:
-            self.test_class_texts = self.class_texts
+        #with open(train_class_json, 'r') as f_in:
+        #    self.class_texts = json.load(f_in)
+        #with open(test_class_json, 'r') as f_in:
+        #    self.test_class_texts = json.load(f_in)
+        #assert self.class_texts != None
+        #if self.test_class_texts == None:
+        #    self.test_class_texts = self.class_texts
         device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = device
         self.tokenizer = None
@@ -84,12 +84,12 @@ class CATSegPredictor(nn.Module):
             prompt_templates = ['A photo of a {} in the scene',]
         else:
             raise NotImplementedError
+
+        #self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+        #self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         self.clip_model = clip_model.float()
         self.clip_preprocess = clip_preprocess
-
-        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
-        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         transformer = Aggregator(
             text_guidance_dim=text_guidance_dim,
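With this change the predictor no longer loads class lists from the JSON files or precomputes text features at construction time; the demo injects the user's classes at runtime and calls `class_embeddings` itself (see `demo/predictor.py` below). An illustrative sketch of what a `class_embeddings`-style helper computes, with names and shapes inferred from the call sites in this diff rather than copied from the repository:

import clip
import torch

def class_embeddings(classnames, templates, clip_model):
    """Embed every (template, class) prompt with CLIP's text encoder."""
    with torch.no_grad():
        per_class = []
        for name in classnames:
            texts = [t.format(name) for t in templates]       # e.g. 'A photo of a cat in the scene'
            tokens = clip.tokenize(texts)                      # (T, 77) token ids
            feats = clip_model.encode_text(tokens)             # (T, D) text embeddings
            feats = feats / feats.norm(dim=-1, keepdim=True)   # unit-normalize each prompt
            per_class.append(feats)
        # Stack classes along dim=1 -> (T, C, D); the call sites in this
        # commit then .permute(1, 0, 2) to get (C, T, D).
        return torch.stack(per_class, dim=1)

Here `clip_model` would come from something like `clip.load("ViT-L/14", device="cpu")`; the per-template normalization and stacking order are assumptions consistent with the `.permute(1, 0, 2)` in the diff.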
demo/predictor.py CHANGED
@@ -43,8 +43,9 @@ class VisualizationDemo(object):
         pred = self.predictor.model.sem_seg_head.predictor
         pred.test_class_texts = text.split(',')
         pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
-                                                        imagenet_templates.IMAGENET_TEMPLATES,
-                                                        pred.clip_model).permute(1, 0, 2).float()
+                                                        #imagenet_templates.IMAGENET_TEMPLATES,
+                                                        ['A photo of a {} in the scene',],
+                                                        pred.clip_model).permute(1, 0, 2).float().repeat(1, 80, 1)
         self.metadata = ns()
         self.metadata.stuff_classes = pred.test_class_texts
 
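The notable detail here is `.repeat(1, 80, 1)`: with a single prompt template the permuted features have only one entry along the template axis, and tiling it 80 times restores the ensemble width the pretrained aggregator presumably expects from the 80-entry `IMAGENET_TEMPLATES` list, while paying for just one text-encoder pass per class. A shape check, with the (C, T, D) layout inferred from the `permute(1, 0, 2)` call (an assumption, not taken from the repository):

import torch

num_classes, dim = 3, 768                    # e.g. 'cat,person,background'; 768 = CLIP ViT-L text width
single = torch.randn(num_classes, 1, dim)    # (C, templates=1, D) after permute(1, 0, 2)
ensembled = single.repeat(1, 80, 1)          # tile the lone template embedding 80x
assert ensembled.shape == (num_classes, 80, dim)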