JeffLiang committed
Commit: 8c62972
Parent: ba09e2c

change sam_vit_h to sam_vit_l to save memory

app.py CHANGED
@@ -45,7 +45,7 @@ def inference(class_names, proposal_gen, granularity, input_img):
     if proposal_gen == 'MaskFormer':
         demo = VisualizationDemo(cfg)
     elif proposal_gen == 'Segment_Anything':
-        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_h_4b8939.pth', './ovseg_clip_l_9a1909.pth')
+        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
     class_names = class_names.split(',')
     img = read_image(input_img, format="BGR")
     _, visualized_output = demo.run_on_image(img, class_names)
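With this change the Gradio demo expects ./sam_vit_l_0b3195.pth next to app.py in place of the larger ViT-H weights. A minimal fetch sketch, assuming the checkpoint is still served from the public Segment Anything release (the URL below is inferred from the filename and should be checked against the segment-anything README):

```python
import urllib.request
from pathlib import Path

# Assumed URL: inferred from the checkpoint filename used in this commit;
# verify it against the official segment-anything release list before relying on it.
SAM_VIT_L_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"

ckpt = Path("./sam_vit_l_0b3195.pth")
if not ckpt.exists():
    urllib.request.urlretrieve(SAM_VIT_L_URL, str(ckpt))
print(f"{ckpt}: {ckpt.stat().st_size / 2**30:.2f} GiB on disk")
```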
open_vocab_seg/utils/predictor.py CHANGED
@@ -150,7 +150,7 @@ class SAMVisualizationDemo(object):
 
         self.parallel = parallel
         self.granularity = granularity
-        sam = sam_model_registry["vit_h"](checkpoint=sam_path).cuda()
+        sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda()
         self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16)
         self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
         self.clip_model.cuda()
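The first hunk swaps the SAM backbone from ViT-H to ViT-L, which roughly halves the checkpoint size and the GPU memory the mask generator needs. A self-contained sketch of loading that backbone with the segment-anything API; the blank test image is a hypothetical stand-in for a real RGB frame:

```python
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# "vit_l" builds the mid-sized SAM backbone; "vit_h" would load the larger one.
sam = sam_model_registry["vit_l"](checkpoint="./sam_vit_l_0b3195.pth")
sam.to("cuda")

# points_per_batch also bounds peak memory: it limits how many prompt points
# are pushed through the mask decoder at once.
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=16)

image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in HxWx3 uint8 RGB image
masks = mask_generator.generate(image)           # list of dicts: 'segmentation', 'area', 'bbox', ...
```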
 
@@ -189,12 +189,17 @@ class SAMVisualizationDemo(object):
         txts = [f'a photo of {cls_name}' for cls_name in class_names]
         text = open_clip.tokenize(txts)
 
+        img_batches = torch.split(imgs, 32, dim=0)
+
         with torch.no_grad(), torch.cuda.amp.autocast():
-            image_features = self.clip_model.encode_image(imgs.cuda().half())
             text_features = self.clip_model.encode_text(text.cuda())
-            image_features /= image_features.norm(dim=-1, keepdim=True)
             text_features /= text_features.norm(dim=-1, keepdim=True)
-
+            image_features = []
+            for img_batch in img_batches:
+                image_feat = self.clip_model.encode_image(img_batch.cuda().half())
+                image_feat /= image_feat.norm(dim=-1, keepdim=True)
+                image_features.append(image_feat.detach())
+            image_features = torch.cat(image_features, dim=0)
             class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
             select_cls = torch.zeros_like(class_preds)
 
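The second hunk replaces a single encode_image call over every SAM proposal crop with chunks of 32, so peak activation memory is bounded by the chunk size rather than by the number of masks. A rough way to confirm that kind of saving is PyTorch's allocator counters; in the sketch below, clip_model and imgs are hypothetical stand-ins for the open_clip model and the stacked mask crops built in SAMVisualizationDemo:

```python
import torch

def peak_cuda_mib(fn) -> float:
    """Reset the CUDA allocator's high-water mark, run `fn` once, and
    return the peak memory it allocated, in MiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad(), torch.cuda.amp.autocast():
        fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

def compare_peak_memory(clip_model, imgs, chunk=32):
    # clip_model / imgs are stand-ins for the objects used above: an open_clip
    # ViT-L/14 model on the GPU and an N x 3 x 224 x 224 tensor of mask crops.
    one_shot = peak_cuda_mib(lambda: clip_model.encode_image(imgs.cuda().half()))
    chunked = peak_cuda_mib(lambda: [clip_model.encode_image(b.cuda().half())
                                     for b in torch.split(imgs, chunk, dim=0)])
    print(f"one-shot: {one_shot:.0f} MiB   chunks of {chunk}: {chunked:.0f} MiB")
```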
 
sam_vit_l_0b3195.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
+size 1249524607
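Because the checkpoint is tracked with Git LFS, a fresh clone only holds this three-line pointer until `git lfs pull` fetches the real weights. A small check of a local copy against the size and sha256 recorded in the pointer:

```python
import hashlib
from pathlib import Path

# Values copied from the LFS pointer committed above.
EXPECTED_SHA256 = "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622"
EXPECTED_SIZE = 1249524607  # bytes (~1.25 GB)

ckpt = Path("sam_vit_l_0b3195.pth")
assert ckpt.stat().st_size == EXPECTED_SIZE, "size mismatch: did `git lfs pull` run?"

sha = hashlib.sha256()
with ckpt.open("rb") as f:
    for block in iter(lambda: f.read(1 << 20), b""):
        sha.update(block)
assert sha.hexdigest() == EXPECTED_SHA256, "sha256 mismatch: corrupt or partial download"
print("sam_vit_l_0b3195.pth matches its LFS pointer")
```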