grostaco committed on
Commit
c5c3fa2
1 Parent(s): ef3b87e

feat: add cosine similarity measure

Browse files
Files changed (2) hide show
  1. app.py +7 -5
  2. lib/utils/model.py +0 -1
app.py CHANGED
@@ -3,6 +3,7 @@ from lib.utils.model import get_model, get_similarities
3
  from PIL import Image
4
 
5
  st.title('IRRA Text-To-Image-Retrival')
 
6
 
7
  st.header('Inputs')
8
  caption = st.text_input('Description Input')
@@ -12,7 +13,7 @@ if images is not None:
12
  st.image(images) # type: ignore
13
 
14
  st.header('Options')
15
- st.subheader('Ranks')
16
 
17
  ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
18
 
@@ -26,15 +27,16 @@ if button:
26
  st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
27
 
28
  with st.spinner('Computing and ranking similarities'):
29
- similarities = get_similarities(caption, images, model)
30
 
31
- indices = similarities.argsort(descending=True).squeeze(0).cpu().tolist()[:ranks]
32
 
33
  for i, idx in enumerate(indices):
34
- c1, c2 = st.columns(2)
35
  with c1:
36
  st.text(f'Rank {i + 1}')
37
  with c2:
38
  st.image(images[idx])
39
-
 
40
 
 
3
  from PIL import Image
4
 
5
  st.title('IRRA Text-To-Image-Retrival')
6
+ st.markdown('A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
7
 
8
  st.header('Inputs')
9
  caption = st.text_input('Description Input')
 
13
  st.image(images) # type: ignore
14
 
15
  st.header('Options')
16
+ st.subheader('Ranks', help='How many predictions the model is allowed to make')
17
 
18
  ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
19
 
 
27
  st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
28
 
29
  with st.spinner('Computing and ranking similarities'):
30
+ similarities = get_similarities(caption, images, model).squeeze(0)
31
 
32
+ indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
33
 
34
  for i, idx in enumerate(indices):
35
+ c1, c2, c3 = st.columns(3)
36
  with c1:
37
  st.text(f'Rank {i + 1}')
38
  with c2:
39
  st.image(images[idx])
40
+ with c3:
41
+ st.text(f'Cosine sim {similarities[idx].cpu():.2f}')
42
 
lib/utils/model.py CHANGED
@@ -24,7 +24,6 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
24
  txt = tokenize(text, tokenizer)
25
  imgs = prepare_images(images)
26
 
27
- print(imgs.shape)
28
  image_feats = model.encode_image(imgs)
29
  text_feats = model.encode_text(txt.unsqueeze(0))
30
 
 
24
  txt = tokenize(text, tokenizer)
25
  imgs = prepare_images(images)
26
 
 
27
  image_feats = model.encode_image(imgs)
28
  text_feats = model.encode_text(txt.unsqueeze(0))
29