AAAAAAyq committed on
Commit 2f10180
1 Parent(s): 5350ba4

Update the examples

Files changed (1)
  1. tools.py +27 -27
tools.py CHANGED
@@ -3,7 +3,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 import cv2
 import torch
-import clip
+# import clip


 def convert_box_xywh_to_xyxy(box):
@@ -290,20 +290,20 @@ def fast_show_mask_gpu(
     return mask_cpu


-# clip
-@torch.no_grad()
-def retriev(
-    model, preprocess, elements, search_text: str, device
-) -> int:
-    preprocessed_images = [preprocess(image).to(device) for image in elements]
-    tokenized_text = clip.tokenize([search_text]).to(device)
-    stacked_images = torch.stack(preprocessed_images)
-    image_features = model.encode_image(stacked_images)
-    text_features = model.encode_text(tokenized_text)
-    image_features /= image_features.norm(dim=-1, keepdim=True)
-    text_features /= text_features.norm(dim=-1, keepdim=True)
-    probs = 100.0 * image_features @ text_features.T
-    return probs[:, 0].softmax(dim=0)
+# # clip
+# @torch.no_grad()
+# def retriev(
+#     model, preprocess, elements, search_text: str, device
+# ) -> int:
+#     preprocessed_images = [preprocess(image).to(device) for image in elements]
+#     tokenized_text = clip.tokenize([search_text]).to(device)
+#     stacked_images = torch.stack(preprocessed_images)
+#     image_features = model.encode_image(stacked_images)
+#     text_features = model.encode_text(tokenized_text)
+#     image_features /= image_features.norm(dim=-1, keepdim=True)
+#     text_features /= text_features.norm(dim=-1, keepdim=True)
+#     probs = 100.0 * image_features @ text_features.T
+#     return probs[:, 0].softmax(dim=0)


 def crop_image(annotations, image_path):
@@ -381,15 +381,15 @@ def point_prompt(masks, points, pointlabel, target_height, target_width): # num
     return onemask, 0


-def text_prompt(annotations, args):
-    cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
-        annotations, args.img_path
-    )
-    clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
-    scores = retriev(
-        clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
-    )
-    max_idx = scores.argsort()
-    max_idx = max_idx[-1]
-    max_idx += sum(np.array(filter_id) <= int(max_idx))
-    return annotaions[max_idx]["segmentation"], max_idx
+# def text_prompt(annotations, args):
+#     cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
+#         annotations, args.img_path
+#     )
+#     clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
+#     scores = retriev(
+#         clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
+#     )
+#     max_idx = scores.argsort()
+#     max_idx = max_idx[-1]
+#     max_idx += sum(np.array(filter_id) <= int(max_idx))
+#     return annotaions[max_idx]["segmentation"], max_idx
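What this commit disables is the CLIP-based text-to-region retrieval path: retriev() scored each cropped mask against the text prompt, and text_prompt() returned the best-matching segmentation. The sketch below restates that scoring step as a standalone snippet so it can be tried outside tools.py; it assumes the OpenAI clip package is installed (pip install git+https://github.com/openai/CLIP.git), and the function name score_crops plus the example file names are illustrative, not part of this repository.

# Minimal sketch of the CLIP scoring path that the commented-out
# retriev()/text_prompt() implemented; score_crops and the file names
# in the usage comment are hypothetical.
import torch
import clip  # OpenAI CLIP, assumed installed
from PIL import Image


@torch.no_grad()
def score_crops(crops, search_text, device="cpu"):
    # Same ViT-B/32 checkpoint that text_prompt() loaded.
    model, preprocess = clip.load("ViT-B/32", device=device)
    images = torch.stack([preprocess(im) for im in crops]).to(device)
    tokens = clip.tokenize([search_text]).to(device)
    image_features = model.encode_image(images)
    text_features = model.encode_text(tokens)
    # L2-normalise, then softmax the scaled cosine similarities over the crops.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    return (100.0 * image_features @ text_features.T)[:, 0].softmax(dim=0)


# Usage: pick the crop that best matches the prompt.
# crops = [Image.open(p) for p in ("crop0.png", "crop1.png")]
# best_idx = int(score_crops(crops, "a dog").argmax())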