jwyang committed on
Commit
294ea6c
1 Parent(s): 92144aa

average over multiple prompts

Browse files
detectron2/modeling/meta_arch/clip_rcnn.py CHANGED
@@ -751,8 +751,10 @@ class CLIPFastRCNN(nn.Module):
751
  text_features = self.backbone.encode_text(queries)
752
  else:
753
  features = self.backbone(images.tensor)
754
- token_embeddings = pre_tokenize([queries])[:, 0].to(images.tensor.device)
755
  text_features = self.lang_encoder.encode_text(token_embeddings)
 
 
756
 
757
  if self.backbone_type == "resnet":
758
  head = self.backbone.layer4
 
751
  text_features = self.backbone.encode_text(queries)
752
  else:
753
  features = self.backbone(images.tensor)
754
+ token_embeddings = pre_tokenize([queries]).to(images.tensor.device)[0]
755
  text_features = self.lang_encoder.encode_text(token_embeddings)
756
+ text_features = text_features.mean(0, keepdim=True)
757
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
758
 
759
  if self.backbone_type == "resnet":
760
  head = self.backbone.layer4