kaisugi committed on
Commit
4a783fe
1 Parent(s): 3f9f23d
Files changed (1)
  1. app.py +6 -6
app.py CHANGED
@@ -17,7 +17,7 @@ def average_pool(last_hidden_states: Tensor,
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_resource
 def load_model_and_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
     model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
@@ -26,14 +26,14 @@ def load_model_and_tokenizer():
     return model, tokenizer
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_resource
 def load_title_data():
     title_df = pd.read_csv('anlp2024.tsv', names=["pid", "title"], sep="\t")
 
     return title_df
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_resource
 def load_title_embeddings():
     npz_comp = np.load("anlp2024.npz")
     title_embeddings = npz_comp["arr_0"]
@@ -41,13 +41,13 @@ def load_title_embeddings():
     return title_embeddings
 
 
-@st.cache
+@st.cache_data
 def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
     batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
     with torch.no_grad():
         outputs = model(**batch_dict)
-    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-    embeddings = F.normalize(embeddings, p=2, dim=1)
+    query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
 
     _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
     retrieved_titles = []
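
The decorator swap tracks Streamlit's deprecation of st.cache (superseded by st.cache_resource and st.cache_data as of Streamlit 1.18): st.cache_resource replaces the old st.cache(allow_output_mutation=True) pattern for shared, unpicklable objects such as the model, tokenizer, and embeddings, while st.cache_data caches serializable results keyed by the function's inputs. A minimal sketch of the split, with hypothetical loader names:

import streamlit as st
import pandas as pd

@st.cache_resource  # one shared object; returned as-is, never copied
def load_heavy_handle() -> dict:
    # stands in for a model/tokenizer/index that cannot be pickled
    return {"handle": object()}

@st.cache_data  # keyed by argument values; callers get a fresh copy
def load_table(path: str) -> pd.DataFrame:
    return pd.read_csv(path, sep="\t")

One caveat worth noting: st.cache_data hashes its arguments, so unhashable parameters like the faiss index or the tokenizer passed to get_retrieval_results would normally need a leading underscore (e.g. _index, _tokenizer) to be excluded from hashing.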
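
For context on the renamed variables: the query path follows the intfloat/multilingual-e5-large model card, which mean-pools the encoder's last hidden states under the attention mask and then L2-normalizes, so similarity search over the precomputed title embeddings reduces to an inner product. A self-contained sketch of that pooling step (the full average_pool body, of which the diff shows only the return line):

import torch
import torch.nn.functional as F
from torch import Tensor

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # zero out padded positions, then average the remaining token states
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

# usage, as in get_retrieval_results above:
# query_embeddings = F.normalize(
#     average_pool(outputs.last_hidden_state, batch_dict['attention_mask']),
#     p=2, dim=1)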