Spaces:
Runtime error
Runtime error
secilozksen
committed on
Commit
·
02ecb0f
1
Parent(s):
98b83d0
Upload 4 files
Browse files
demo with new dataset commit
- basecamp-dpr-context-embeddings.pkl +3 -0
- basecamp.csv +0 -0
- demo_dpr.py +25 -41
- st-context-embeddings.pkl +3 -0
basecamp-dpr-context-embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18843457511ccc9cd7e998dafac0339d60dcc9984a69fcf884f9e96d2fd11d15
|
3 |
+
size 68535357
|
basecamp.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
demo_dpr.py
CHANGED
@@ -16,7 +16,7 @@ import tokenizers
|
|
16 |
st.set_page_config(layout="wide")
|
17 |
|
18 |
DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
|
19 |
-
DATAFRAME_FILE_BSBS = '
|
20 |
|
21 |
selectbox_selections = {
|
22 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
|
@@ -68,22 +68,21 @@ def load_paragraphs(path):
|
|
68 |
|
69 |
@st.cache(show_spinner=False)
|
70 |
def load_dataframes():
|
71 |
-
|
72 |
data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
|
73 |
-
|
|
|
74 |
data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
|
75 |
-
return
|
76 |
|
77 |
def dot_product(question_output, context_output):
|
78 |
-
mat1 = torch.
|
79 |
-
mat2 = torch.
|
80 |
-
result = torch.
|
81 |
-
result = torch.squeeze(result, dim=1)
|
82 |
-
result = torch.squeeze(result, dim=1)
|
83 |
return result
|
84 |
|
85 |
def retrieve_rerank_DPR(question):
|
86 |
-
hits =
|
87 |
return rerank_with_DPR(hits, question)
|
88 |
|
89 |
def DPR_reranking(question, selected_contexes, selected_embeddings):
|
@@ -124,7 +123,7 @@ def custom_dpr_pipeline(question):
|
|
124 |
results_list = []
|
125 |
for i,context_embedding in enumerate(dpr_context_embeddings):
|
126 |
score = dot_product(question_embedding, context_embedding)
|
127 |
-
results_list.append(score.detach().cpu()
|
128 |
|
129 |
hits = sorted(range(len(results_list)), key=lambda i: results_list[i], reverse=True)
|
130 |
top_5_contexes = []
|
@@ -134,10 +133,10 @@ def custom_dpr_pipeline(question):
|
|
134 |
top_5_scores.append(results_list[j])
|
135 |
return top_5_contexes, top_5_scores
|
136 |
|
137 |
-
def retrieve(question
|
138 |
# Semantic Search (Retrieve)
|
139 |
question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
|
140 |
-
hits = util.semantic_search(question_embedding,
|
141 |
if len(hits) == 0:
|
142 |
return []
|
143 |
hits = hits[0]
|
@@ -156,41 +155,22 @@ def retrieve_with_dpr_embeddings(question):
|
|
156 |
if len(hits) == 0:
|
157 |
return []
|
158 |
hits = hits[0]
|
159 |
-
return hits
|
160 |
|
161 |
-
def rerank_with_DPR(hits,
|
162 |
# Rerank - score all retrieved passages with cross-encoder
|
163 |
selected_contexes = [dpr_contexes[hit['corpus_id']] for hit in hits]
|
164 |
selected_embeddings = [dpr_context_embeddings[hit['corpus_id']] for hit in hits]
|
165 |
-
top_5_scores, top_5_contexes = DPR_reranking(
|
166 |
return top_5_contexes, top_5_scores
|
167 |
|
168 |
-
def DPR_reranking(question, selected_contexes, selected_embeddings):
|
169 |
-
scores = []
|
170 |
-
tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
|
171 |
-
add_special_tokens=True)
|
172 |
-
question_output = dpr_trained.model.question_model(**tokenized_question)
|
173 |
-
question_output = question_output['pooler_output']
|
174 |
-
for context_embedding in selected_embeddings:
|
175 |
-
score = dot_product(question_output, context_embedding)
|
176 |
-
scores.append(score.detach().cpu().numpy()[0])
|
177 |
-
|
178 |
-
scores_index = sorted(range(len(scores)), key=lambda x: scores[x], reverse=True)
|
179 |
-
contexes_list = []
|
180 |
-
scores_final = []
|
181 |
-
for i, idx in enumerate(scores_index[:5]):
|
182 |
-
scores_final.append(scores[idx])
|
183 |
-
contexes_list.append(selected_contexes[idx])
|
184 |
-
return scores_final, contexes_list
|
185 |
-
|
186 |
-
|
187 |
def retrieve_rerank_with_trained_cross_encoder(question):
|
188 |
-
hits = retrieve(question
|
189 |
cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
|
190 |
cross_scores = trained_cross_encoder.predict(cross_inp)
|
191 |
# Sort results by the cross-encoder scores
|
192 |
for idx in range(len(cross_scores)):
|
193 |
-
hits[idx]['cross-score'] = cross_scores[idx][
|
194 |
|
195 |
# Output of top-5 hits from re-ranker
|
196 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
@@ -229,7 +209,7 @@ def img_to_bytes(img_path):
|
|
229 |
return encoded
|
230 |
|
231 |
def qa_main_widgetsv2():
|
232 |
-
st.title("
|
233 |
st.markdown("""---""")
|
234 |
option = st.selectbox("Select a search method:", list(selectbox_selections.keys()))
|
235 |
header_html = "<center> <img src='data:image/png;base64,{}' class='img-fluid' width='60%', height='40%'> </center>".format(
|
@@ -289,9 +269,13 @@ def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
|
|
289 |
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
|
290 |
return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
|
291 |
|
292 |
-
context_embeddings, contexes = load_paragraphs('context-embeddings.pkl')
|
293 |
-
dpr_context_embeddings, dpr_contexes = load_paragraphs('
|
294 |
-
|
295 |
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
|
296 |
|
297 |
qa_main_widgetsv2()
|
|
|
|
|
|
|
|
|
|
16 |
st.set_page_config(layout="wide")
|
17 |
|
18 |
DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
|
19 |
+
DATAFRAME_FILE_BSBS = 'basecamp.csv'
|
20 |
|
21 |
selectbox_selections = {
|
22 |
'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
|
|
|
68 |
|
69 |
@st.cache(show_spinner=False)
|
70 |
def load_dataframes():
|
71 |
+
# data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
|
72 |
data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
|
73 |
+
data_bsbs.drop('context_id', axis=1, inplace=True)
|
74 |
+
# data_original = data_original.sample(frac=1).reset_index(drop=True)
|
75 |
data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
|
76 |
+
return data_bsbs
|
77 |
|
78 |
def dot_product(question_output, context_output):
|
79 |
+
mat1 = torch.squeeze(question_output, 0)
|
80 |
+
mat2 = torch.squeeze(context_output, 0)
|
81 |
+
result = torch.dot(mat1, mat2)
|
|
|
|
|
82 |
return result
|
83 |
|
84 |
def retrieve_rerank_DPR(question):
|
85 |
+
hits = retrieve(question)
|
86 |
return rerank_with_DPR(hits, question)
|
87 |
|
88 |
def DPR_reranking(question, selected_contexes, selected_embeddings):
|
|
|
123 |
results_list = []
|
124 |
for i,context_embedding in enumerate(dpr_context_embeddings):
|
125 |
score = dot_product(question_embedding, context_embedding)
|
126 |
+
results_list.append(score.detach().cpu())
|
127 |
|
128 |
hits = sorted(range(len(results_list)), key=lambda i: results_list[i], reverse=True)
|
129 |
top_5_contexes = []
|
|
|
133 |
top_5_scores.append(results_list[j])
|
134 |
return top_5_contexes, top_5_scores
|
135 |
|
136 |
+
def retrieve(question):
|
137 |
# Semantic Search (Retrieve)
|
138 |
question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
|
139 |
+
hits = util.semantic_search(question_embedding, context_embeddings, top_k=100)
|
140 |
if len(hits) == 0:
|
141 |
return []
|
142 |
hits = hits[0]
|
|
|
155 |
if len(hits) == 0:
|
156 |
return []
|
157 |
hits = hits[0]
|
158 |
+
return hits, question_embedding
|
159 |
|
160 |
+
def rerank_with_DPR(hits, question_embedding):
|
161 |
# Rerank - score all retrieved passages with cross-encoder
|
162 |
selected_contexes = [dpr_contexes[hit['corpus_id']] for hit in hits]
|
163 |
selected_embeddings = [dpr_context_embeddings[hit['corpus_id']] for hit in hits]
|
164 |
+
top_5_scores, top_5_contexes = DPR_reranking(question_embedding, selected_contexes, selected_embeddings)
|
165 |
return top_5_contexes, top_5_scores
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
def retrieve_rerank_with_trained_cross_encoder(question):
|
168 |
+
hits = retrieve(question)
|
169 |
cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
|
170 |
cross_scores = trained_cross_encoder.predict(cross_inp)
|
171 |
# Sort results by the cross-encoder scores
|
172 |
for idx in range(len(cross_scores)):
|
173 |
+
hits[idx]['cross-score'] = cross_scores[idx][1]
|
174 |
|
175 |
# Output of top-5 hits from re-ranker
|
176 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
|
|
209 |
return encoded
|
210 |
|
211 |
def qa_main_widgetsv2():
|
212 |
+
st.title("Question Answering Demo")
|
213 |
st.markdown("""---""")
|
214 |
option = st.selectbox("Select a search method:", list(selectbox_selections.keys()))
|
215 |
header_html = "<center> <img src='data:image/png;base64,{}' class='img-fluid' width='60%', height='40%'> </center>".format(
|
|
|
269 |
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
|
270 |
return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
|
271 |
|
272 |
+
context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
|
273 |
+
dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-context-embeddings.pkl')
|
274 |
+
dataframe_bsbs = load_dataframes()
|
275 |
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
|
276 |
|
277 |
qa_main_widgetsv2()
|
278 |
+
|
279 |
+
#if __name__ == '__main__':
|
280 |
+
# search_pipeline('Life insurance is paid by insurance companies that pay for what?', 1)
|
281 |
+
|
st-context-embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd65fe793062375df1efd50218e9a7c35253fe06a24e5527de7855671a4f958c
|
3 |
+
size 468299
|