SajjadAyoubi commited on
Commit
6b7167b
1 Parent(s): 99c7487

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -1,24 +1,21 @@
1
  import streamlit as st
2
- import pandas as pd, numpy as np
3
  from html import escape
4
  import os
5
  import torch
6
  from transformers import RobertaModel, AutoTokenizer
7
 
8
 
9
- @st.cache(show_spinner=False,
10
- hash_funcs={text_encoder: lambda _: None,
11
- tokenizer: lambda _: None,
12
- dict: lambda _: None})
13
  def load():
14
  text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
15
  tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
16
- df = pd.read_csv('data.csv')
17
- image_embeddings = np.load('embeddings.npy')
18
- return text_encoder, tokenizer, df, image_embeddings
19
 
20
 
21
- text_encoder, tokenizer, df, image_embeddings = load()
22
 
23
 
24
  def get_html(url_list, height=224):
@@ -34,12 +31,12 @@ def get_html(url_list, height=224):
34
  return html
35
 
36
 
37
- st.cache(show_spinner=False)
38
  def image_search(query, top_k=8):
39
  with torch.no_grad():
40
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
41
  values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
42
- return [(df.iloc[i]['path'], df.iloc[i]['link']) for i in indices[:top_k]]
43
 
44
 
45
  description = '''
 
1
  import streamlit as st
2
+ import pandas as pd
3
  from html import escape
4
  import os
5
  import torch
6
  from transformers import RobertaModel, AutoTokenizer
7
 
8
 
9
+ @st.cache(show_spinner=False)
 
 
 
10
  def load():
11
  text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
12
  tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
13
+ link_df = pd.read_csv('links.csv')
14
+ image_embeddings = torch.load('embeddings.pt')
15
+ return text_encoder, tokenizer, link_df, image_embeddings
16
 
17
 
18
+ text_encoder, tokenizer, link_df, image_embeddings = load()
19
 
20
 
21
  def get_html(url_list, height=224):
 
31
  return html
32
 
33
 
34
+ @st.cache(show_spinner=False)
35
  def image_search(query, top_k=8):
36
  with torch.no_grad():
37
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
38
  values, indices = torch.cosine_similarity(text_embedding, image_embeddings).sort(descending=True)
39
+ return [(link_df.iloc[i]['path'], link_df.iloc[i]['link']) for i in indices[:top_k]]
40
 
41
 
42
  description = '''