Spaces:
Sleeping
Sleeping
fix
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ def average_pool(last_hidden_states: Tensor,
|
|
17 |
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
18 |
|
19 |
|
20 |
-
@st.
|
21 |
def load_model_and_tokenizer():
|
22 |
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
|
23 |
model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
|
@@ -26,14 +26,14 @@ def load_model_and_tokenizer():
|
|
26 |
return model, tokenizer
|
27 |
|
28 |
|
29 |
-
@st.
|
30 |
def load_title_data():
|
31 |
title_df = pd.read_csv('anlp2024.tsv', names=["pid", "title"], sep="\t")
|
32 |
|
33 |
return title_df
|
34 |
|
35 |
|
36 |
-
@st.
|
37 |
def load_title_embeddings():
|
38 |
npz_comp = np.load("anlp2024.npz")
|
39 |
title_embeddings = npz_comp["arr_0"]
|
@@ -41,13 +41,13 @@ def load_title_embeddings():
|
|
41 |
return title_embeddings
|
42 |
|
43 |
|
44 |
-
@st.
|
45 |
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
|
46 |
batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
|
47 |
with torch.no_grad():
|
48 |
outputs = model(**batch_dict)
|
49 |
-
|
50 |
-
|
51 |
|
52 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
53 |
retrieved_titles = []
|
|
|
17 |
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
18 |
|
19 |
|
20 |
+
@st.cache_resource
|
21 |
def load_model_and_tokenizer():
|
22 |
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
|
23 |
model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
|
|
|
26 |
return model, tokenizer
|
27 |
|
28 |
|
29 |
+
@st.cache_resource
|
30 |
def load_title_data():
|
31 |
title_df = pd.read_csv('anlp2024.tsv', names=["pid", "title"], sep="\t")
|
32 |
|
33 |
return title_df
|
34 |
|
35 |
|
36 |
+
@st.cache_resource
|
37 |
def load_title_embeddings():
|
38 |
npz_comp = np.load("anlp2024.npz")
|
39 |
title_embeddings = npz_comp["arr_0"]
|
|
|
41 |
return title_embeddings
|
42 |
|
43 |
|
44 |
+
@st.cache_data
|
45 |
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
|
46 |
batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
|
47 |
with torch.no_grad():
|
48 |
outputs = model(**batch_dict)
|
49 |
+
query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
50 |
+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
51 |
|
52 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
53 |
retrieved_titles = []
|