grostaco commited on
Commit
7e39033
1 Parent(s): 3617c74

feat: normalize cosine similarity

Browse files
Files changed (1) hide show
  1. lib/utils/model.py +4 -0
lib/utils/model.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import yaml
3
  import torch
 
4
 
5
  from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
6
  from lib.IRRA.image import prepare_images
@@ -27,4 +28,7 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
27
  image_feats = model.encode_image(imgs)
28
  text_feats = model.encode_text(txt.unsqueeze(0))
29
 
 
 
 
30
  return text_feats @ image_feats.t()
 
1
  import streamlit as st
2
  import yaml
3
  import torch
4
+ import torch.nn.functional as F
5
 
6
  from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
7
  from lib.IRRA.image import prepare_images
 
28
  image_feats = model.encode_image(imgs)
29
  text_feats = model.encode_text(txt.unsqueeze(0))
30
 
31
+ image_feats = F.normalize(image_feats, p=2, dim=1)
32
+ text_feats = F.normalize(text_feats, p=2, dim=1)
33
+
34
  return text_feats @ image_feats.t()