Spaces:
Runtime error
Runtime error
secilozksen
commited on
Commit
•
30dce9f
1
Parent(s):
18665b8
files updated
Browse files- basecamp-dpr-contriever-embeddings.pkl +3 -0
- basecamp.csv +0 -0
- demo_dpr.py +16 -57
- st-context-embeddings.pkl +2 -2
basecamp-dpr-contriever-embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:413837017c7b17e8e44556d9ab0cc9d42c9b24d3d28b29a39f3e7e143bd9f482
|
3 |
+
size 856086
|
basecamp.csv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
demo_dpr.py
CHANGED
@@ -7,7 +7,7 @@ from sentence_transformers.cross_encoder import CrossEncoder
|
|
7 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
8 |
import pickle
|
9 |
import torch
|
10 |
-
from transformers import
|
11 |
from pathlib import Path
|
12 |
import base64
|
13 |
import io
|
@@ -20,13 +20,11 @@ DATAFRAME_FILE_BSBS = 'basecamp.csv'
|
|
20 |
selectbox_selections = {
|
21 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
|
22 |
'Dense Passage Retrieval':2,
|
23 |
-
'Retrieve - Reranking with DPR':3,
|
24 |
'Retrieve - Rerank':4
|
25 |
}
|
26 |
imagebox_selections = {
|
27 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
|
28 |
'Dense Passage Retrieval': 'DPR_pipeline.png',
|
29 |
-
'Retrieve - Reranking with DPR': 'Retrieve-rerank-DPR.png',
|
30 |
'Retrieve - Rerank': 'retrieve-rerank.png'
|
31 |
}
|
32 |
|
@@ -63,7 +61,7 @@ class CPU_Unpickler(pickle.Unpickler):
|
|
63 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
64 |
def load_paragraphs(path):
|
65 |
with open(path, "rb") as fIn:
|
66 |
-
cache_data =
|
67 |
corpus_sentences = cache_data['contexes']
|
68 |
corpus_embeddings = cache_data['embeddings']
|
69 |
|
@@ -84,45 +82,25 @@ def dot_product(question_output, context_output):
|
|
84 |
result = torch.dot(mat1, mat2)
|
85 |
return result
|
86 |
|
87 |
-
def retrieve_rerank_DPR(question):
|
88 |
-
hits = retrieve(question)
|
89 |
-
return rerank_with_DPR(hits, question)
|
90 |
-
|
91 |
-
def DPR_reranking(question, selected_contexes, selected_embeddings):
|
92 |
-
scores = []
|
93 |
-
tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
|
94 |
-
add_special_tokens=True)
|
95 |
-
question_output = dpr_trained.model.question_model(**tokenized_question)
|
96 |
-
question_output = question_output['pooler_output']
|
97 |
-
for context_embedding in selected_embeddings:
|
98 |
-
score = dot_product(question_output, context_embedding)
|
99 |
-
scores.append(score.detach().cpu())
|
100 |
-
|
101 |
-
scores_index = sorted(range(len(scores)), key=lambda x: scores[x], reverse=True)
|
102 |
-
contexes_list = []
|
103 |
-
scores_final = []
|
104 |
-
for i, idx in enumerate(scores_index[:5]):
|
105 |
-
scores_final.append(scores[idx])
|
106 |
-
contexes_list.append(selected_contexes[idx])
|
107 |
-
return scores_final, contexes_list
|
108 |
-
|
109 |
def search_pipeline(question, search_method):
|
110 |
if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
|
111 |
return retrieve_rerank_with_trained_cross_encoder(question)
|
112 |
if search_method == 2:
|
113 |
return custom_dpr_pipeline(question) # DPR only
|
114 |
-
if search_method == 3:
|
115 |
-
return retrieve_rerank_DPR(question)
|
116 |
if search_method == 4:
|
117 |
return retrieve_rerank(question)
|
118 |
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def custom_dpr_pipeline(question):
|
121 |
#paragraphs
|
122 |
-
tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt"
|
123 |
-
add_special_tokens=True)
|
124 |
question_embedding = dpr_trained.model.question_model(**tokenized_question)
|
125 |
-
question_embedding = question_embedding['
|
|
|
126 |
results_list = []
|
127 |
for i,context_embedding in enumerate(dpr_context_embeddings):
|
128 |
score = dot_product(question_embedding, context_embedding)
|
@@ -145,35 +123,13 @@ def retrieve(question):
|
|
145 |
hits = hits[0]
|
146 |
return hits
|
147 |
|
148 |
-
def retrieve_with_dpr_embeddings(question):
|
149 |
-
# Semantic Search (Retrieve)
|
150 |
-
question_tokens = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
|
151 |
-
add_special_tokens=True)
|
152 |
-
|
153 |
-
question_embedding = dpr_trained.model.question_model(**question_tokens)['pooler_output']
|
154 |
-
question_embedding = torch.squeeze(question_embedding, dim=0)
|
155 |
-
corpus_embeddings = torch.stack(dpr_context_embeddings)
|
156 |
-
corpus_embeddings = torch.squeeze(corpus_embeddings, dim=1)
|
157 |
-
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100, score_function=util.dot_score)
|
158 |
-
if len(hits) == 0:
|
159 |
-
return []
|
160 |
-
hits = hits[0]
|
161 |
-
return hits, question_embedding
|
162 |
-
|
163 |
-
def rerank_with_DPR(hits, question_embedding):
|
164 |
-
# Rerank - score all retrieved passages with cross-encoder
|
165 |
-
selected_contexes = [dpr_contexes[hit['corpus_id']] for hit in hits]
|
166 |
-
selected_embeddings = [dpr_context_embeddings[hit['corpus_id']] for hit in hits]
|
167 |
-
top_5_scores, top_5_contexes = DPR_reranking(question_embedding, selected_contexes, selected_embeddings)
|
168 |
-
return top_5_contexes, top_5_scores
|
169 |
-
|
170 |
def retrieve_rerank_with_trained_cross_encoder(question):
|
171 |
hits = retrieve(question)
|
172 |
cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
|
173 |
cross_scores = trained_cross_encoder.predict(cross_inp)
|
174 |
# Sort results by the cross-encoder scores
|
175 |
for idx in range(len(cross_scores)):
|
176 |
-
hits[idx]['cross-score'] = cross_scores[idx]
|
177 |
|
178 |
# Output of top-5 hits from re-ranker
|
179 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
@@ -263,19 +219,22 @@ def qa_main_widgetsv2():
|
|
263 |
|
264 |
@st.cache(show_spinner=False, allow_output_mutation = True)
|
265 |
def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
|
|
|
266 |
dpr_trained = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
|
267 |
trust_remote_code=True)
|
268 |
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
|
269 |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
270 |
bi_encoder.max_seq_length = 500
|
271 |
trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
|
272 |
-
question_tokenizer =
|
273 |
return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
|
274 |
|
275 |
context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
|
276 |
-
dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-
|
277 |
dataframe_bsbs = load_dataframes()
|
278 |
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
|
279 |
-
|
280 |
qa_main_widgetsv2()
|
281 |
|
|
|
|
|
|
|
|
7 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
8 |
import pickle
|
9 |
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModel
|
11 |
from pathlib import Path
|
12 |
import base64
|
13 |
import io
|
|
|
20 |
selectbox_selections = {
|
21 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
|
22 |
'Dense Passage Retrieval':2,
|
|
|
23 |
'Retrieve - Rerank':4
|
24 |
}
|
25 |
imagebox_selections = {
|
26 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
|
27 |
'Dense Passage Retrieval': 'DPR_pipeline.png',
|
|
|
28 |
'Retrieve - Rerank': 'retrieve-rerank.png'
|
29 |
}
|
30 |
|
|
|
61 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
62 |
def load_paragraphs(path):
|
63 |
with open(path, "rb") as fIn:
|
64 |
+
cache_data = pickle.load(fIn)
|
65 |
corpus_sentences = cache_data['contexes']
|
66 |
corpus_embeddings = cache_data['embeddings']
|
67 |
|
|
|
82 |
result = torch.dot(mat1, mat2)
|
83 |
return result
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
def search_pipeline(question, search_method):
|
86 |
if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
|
87 |
return retrieve_rerank_with_trained_cross_encoder(question)
|
88 |
if search_method == 2:
|
89 |
return custom_dpr_pipeline(question) # DPR only
|
|
|
|
|
90 |
if search_method == 4:
|
91 |
return retrieve_rerank(question)
|
92 |
|
93 |
+
def mean_pooling(token_embeddings, mask):
|
94 |
+
token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
|
95 |
+
sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
|
96 |
+
return sentence_embeddings
|
97 |
|
98 |
def custom_dpr_pipeline(question):
|
99 |
#paragraphs
|
100 |
+
tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
|
|
|
101 |
question_embedding = dpr_trained.model.question_model(**tokenized_question)
|
102 |
+
question_embedding = mean_pooling(question_embedding[0], tokenized_question['attention_mask'])
|
103 |
+
# question_embedding = question_embedding['pooler_output']
|
104 |
results_list = []
|
105 |
for i,context_embedding in enumerate(dpr_context_embeddings):
|
106 |
score = dot_product(question_embedding, context_embedding)
|
|
|
123 |
hits = hits[0]
|
124 |
return hits
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def retrieve_rerank_with_trained_cross_encoder(question):
|
127 |
hits = retrieve(question)
|
128 |
cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
|
129 |
cross_scores = trained_cross_encoder.predict(cross_inp)
|
130 |
# Sort results by the cross-encoder scores
|
131 |
for idx in range(len(cross_scores)):
|
132 |
+
hits[idx]['cross-score'] = cross_scores[idx]
|
133 |
|
134 |
# Output of top-5 hits from re-ranker
|
135 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
|
|
219 |
|
220 |
@st.cache(show_spinner=False, allow_output_mutation = True)
|
221 |
def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
|
222 |
+
|
223 |
dpr_trained = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
|
224 |
trust_remote_code=True)
|
225 |
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
|
226 |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
227 |
bi_encoder.max_seq_length = 500
|
228 |
trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
|
229 |
+
question_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
|
230 |
return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
|
231 |
|
232 |
context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
|
233 |
+
dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
|
234 |
dataframe_bsbs = load_dataframes()
|
235 |
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
|
|
|
236 |
qa_main_widgetsv2()
|
237 |
|
238 |
+
#if __name__ == '__main__':
|
239 |
+
# top_5_contexes, top_5_scores = search_pipeline('What are the benefits of 37Signals Visa Card?', 1)
|
240 |
+
|
st-context-embeddings.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79e231244e12074d5e22f46cf3da70f4f1dd43cc6e82f36959d2c6817f2e2bf2
|
3 |
+
size 441107
|