cheng committed on
Commit
863c45d
1 Parent(s): baab402

update PIL image

Browse files
Files changed (3) hide show
  1. clip_component.py +2 -2
  2. detector.py +3 -4
  3. grounding_component.py +2 -2
clip_component.py CHANGED
@@ -17,8 +17,8 @@ def get_token_from_clip(image):
17
  text_features = model.encode_text(text_tokens).float()
18
  text_features /= text_features.norm(dim=-1, keepdim=True)
19
 
20
- image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
21
- image_input = preprocess(image).unsqueeze(0).to(device) # Add batch dimension
22
 
23
  with torch.no_grad():
24
  image_feature = model.encode_image(image_input)
 
17
  text_features = model.encode_text(text_tokens).float()
18
  text_features /= text_features.norm(dim=-1, keepdim=True)
19
 
20
+ image_pil = Image.fromarray(image.astype('uint8'))
21
+ image_input = preprocess(image_pil).unsqueeze(0).to(device) # Add batch dimension
22
 
23
  with torch.no_grad():
24
  image_feature = model.encode_image(image_input)
detector.py CHANGED
@@ -2,8 +2,7 @@ from clip_component import get_token_from_clip
2
  from grounding_component import run_grounding
3
 
4
  def detect(image):
5
- token = get_token_from_clip(image)
6
- print('token')
7
- print(token)
8
- predict_image = run_grounding(image,token)
9
  return predict_image
 
2
  from grounding_component import run_grounding
3
 
4
  def detect(image):
5
+ describe = get_token_from_clip(image)
6
+ print('describe:',describe)
7
+ predict_image = run_grounding(image,describe)
 
8
  return predict_image
grounding_component.py CHANGED
@@ -57,10 +57,10 @@ def image_transform_grounding_for_vis(init_image):
57
 
58
  model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
59
 
60
- def run_grounding(input_image, token):
61
  pil_img = Image.fromarray(input_image)
62
  init_image = pil_img.convert("RGB")
63
- grounding_caption = "token"
64
  box_threshold = 0.25
65
  text_threshold = 0.25
66
 
 
57
 
58
  model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
59
 
60
+ def run_grounding(input_image, describe):
61
  pil_img = Image.fromarray(input_image)
62
  init_image = pil_img.convert("RGB")
63
+ grounding_caption = describe
64
  box_threshold = 0.25
65
  text_threshold = 0.25
66