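# Page residue from the Hugging Face file viewer (kept as comments so the module parses):
# secilozksen's picture
# Upload 14 files
# bbe9860
# raw history blame
# No virus
# 13.3 kB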
import copy
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.cross_encoder import CrossEncoder
from st_aggrid import GridOptionsBuilder, AgGrid
import pickle
import torch
from transformers import DPRQuestionEncoderTokenizer, AutoModel
from pathlib import Path
import base64
import regex
import tokenizers
# Wide layout for the three-column page; this is the first Streamlit call in the script.
st.set_page_config(layout="wide")

DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
DATAFRAME_FILE_BSBS = 'policyQA_bsbs_sentence.csv'

# One row per search method:
#   (label shown in the selectbox, id dispatched by search_pipeline, pipeline diagram image).
# Both lookup tables below are derived from this single table so they cannot drift apart;
# dict insertion order follows row order, which is the order shown in the selectbox.
_SEARCH_METHOD_TABLE = (
    ('Retrieve - Rerank (with fine-tuned cross-encoder)', 1, 'Retrieve-rerank-trained-cross-encoder.png'),
    ('Dense Passage Retrieval', 2, 'DPR_pipeline.png'),
    ('Retrieve - Reranking with DPR', 3, 'Retrieve-rerank-DPR.png'),
    ('Retrieve - Rerank', 4, 'retrieve-rerank.png'),
)

# Selectbox label -> numeric search-method id.
selectbox_selections = {label: method_id for label, method_id, _image in _SEARCH_METHOD_TABLE}
# Selectbox label -> image file rendered above the search form.
imagebox_selections = {label: image for label, _method_id, image in _SEARCH_METHOD_TABLE}
def retrieve_rerank(question, top_k=5):
    """Retrieve candidates with the bi-encoder, then rerank them with the cross-encoder.

    Uses the module-level ``bi_encoder``, ``context_embeddings``, ``contexes`` and
    ``cross_encoder``.

    Args:
        question: Query text.
        top_k: Number of reranked passages to return (the other pipelines return 5).

    Returns:
        A ``(contexes, scores)`` tuple of parallel lists, best first. Both lists are
        empty when retrieval yields no candidates, so callers can always unpack two values.
    """
    # Semantic search (retrieve): up to 100 candidates by embedding similarity.
    question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, context_embeddings, top_k=100)
    # One hit list per query; a single query was passed.
    hits = hits[0] if hits else []
    if not hits:
        return [], []

    # Rerank: score every (question, passage) pair with the cross-encoder.
    cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_paragraphs(path):
    """Load a pickled dict with ``'contexes'`` and ``'embeddings'`` entries from *path*.

    Returns:
        ``(embeddings, contexes)`` -- note the embeddings come first.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted, locally produced files.
    """
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    passages = payload['contexes']
    vectors = payload['embeddings']
    return vectors, passages
@st.cache(show_spinner=False)
def load_dataframes():
    """Read both PolicyQA CSV files and shuffle their rows.

    Each file is '|'-separated with its first column used as the index. Rows are
    shuffled (unseeded) and re-indexed from 0.

    Returns:
        ``(original_dataframe, bsbs_dataframe)``.
    """
    raw_frames = [
        pd.read_csv(csv_path, index_col=0, sep='|')
        for csv_path in (DATAFRAME_FILE_ORIGINAL, DATAFRAME_FILE_BSBS)
    ]
    shuffled_original, shuffled_bsbs = (
        frame.sample(frac=1).reset_index(drop=True) for frame in raw_frames
    )
    return shuffled_original, shuffled_bsbs
def dot_product(question_output, context_output):
    """Row-wise dot product of two batches of vectors.

    Both arguments must be 2-D ``(batch, dim)`` tensors with the same batch size
    (``bmm`` requires matching 3-D operands after the unsqueeze). The result is a
    1-D tensor whose i-th entry is ``question_output[i] . context_output[i]``.
    """
    rows = question_output.unsqueeze(1)     # (batch, 1, dim)
    columns = context_output.unsqueeze(2)   # (batch, dim, 1)
    # (batch, 1, 1) -> drop both singleton axes -> (batch,)
    return rows.bmm(columns).squeeze(1).squeeze(1)
def retrieve_rerank_DPR(question):
    """Retrieve candidates from the DPR embeddings, then rerank them by DPR score.

    Composes ``retrieve_with_dpr_embeddings`` and ``rerank_with_DPR``; returns the
    latter's ``(contexes, scores)`` tuple.
    """
    candidate_hits = retrieve_with_dpr_embeddings(question)
    ranked = rerank_with_DPR(candidate_hits, question)
    return ranked
def DPR_reranking(question, selected_contexes, selected_embeddings, top_k=5):
    """Rerank candidate passages by DPR dot-product score against *question*.

    NOTE(review): ``DPR_reranking`` is defined a second time later in this module and
    that later definition rebinds the name, so this copy is never called. It mirrors
    the later one (scores as numpy scalars, which the UI formats with ``:.2f``) so
    deleting either copy is safe. TODO: remove the duplicate.

    Args:
        question: Query text, encoded with the module-level ``question_tokenizer``
            and ``dpr_trained`` question model.
        selected_contexes: Candidate passages, parallel to *selected_embeddings*.
        selected_embeddings: Candidate DPR embeddings, each accepted by ``dot_product``.
        top_k: Number of passages to keep.

    Returns:
        ``(scores, contexes)`` -- scores first -- for the best *top_k* candidates,
        highest score first.
    """
    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
                                                add_special_tokens=True)
        question_output = dpr_trained.model.question_model(**tokenized_question)['pooler_output']
        scores = [dot_product(question_output, context_embedding).detach().cpu().numpy()[0]
                  for context_embedding in selected_embeddings]
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    scores_final = [scores[idx] for idx in ranking]
    contexes_list = [selected_contexes[idx] for idx in ranking]
    return scores_final, contexes_list
def search_pipeline(question, search_method):
    """Run the search pipeline identified by *search_method* on *question*.

    Method ids (see ``selectbox_selections``):
        1 -- bi-encoder retrieve, rerank with the fine-tuned cross-encoder
        2 -- custom DPR scoring of the whole corpus (no retrieval stage)
        3 -- retrieve with DPR embeddings, rerank by DPR dot product
        4 -- bi-encoder retrieve, rerank with the off-the-shelf cross-encoder

    Returns:
        Whatever the selected pipeline returns (a ``(contexes, scores)`` pair).

    Raises:
        ValueError: if *search_method* is not one of the ids above. Previously an
            unknown id silently returned ``None`` and failed later at unpacking.
    """
    if search_method == 1:
        return retrieve_rerank_with_trained_cross_encoder(question)
    if search_method == 2:
        return custom_dpr_pipeline(question)
    if search_method == 3:
        return retrieve_rerank_DPR(question)
    if search_method == 4:
        return retrieve_rerank(question)
    raise ValueError(f"Unknown search method: {search_method!r}")
def custom_dpr_pipeline(question, top_k=5):
    """Score every DPR context against the DPR-encoded question and return the best.

    There is no retrieval stage: every entry of the module-level
    ``dpr_context_embeddings`` is scored with ``dot_product``.

    Args:
        question: Query text.
        top_k: Number of passages to return.

    Returns:
        ``(contexes, scores)`` for the *top_k* highest-scoring passages, best first;
        both lists are empty when the corpus is empty.
    """
    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
                                                add_special_tokens=True)
        question_embedding = dpr_trained.model.question_model(**tokenized_question)['pooler_output']
        results_list = [dot_product(question_embedding, context_embedding).detach().cpu().numpy()[0]
                        for context_embedding in dpr_context_embeddings]
    ranking = sorted(range(len(results_list)), key=lambda idx: results_list[idx], reverse=True)[:top_k]
    top_contexes = [dpr_contexes[idx] for idx in ranking]
    top_scores = [results_list[idx] for idx in ranking]
    return top_contexes, top_scores
def retrieve(question, corpus_embeddings):
    """Return the bi-encoder semantic-search hits (at most 100) for *question*.

    Each hit is a dict from ``util.semantic_search``; callers in this module read
    its ``'corpus_id'`` key. Returns ``[]`` if no per-query result list comes back.
    """
    encoded_query = bi_encoder.encode(question, convert_to_tensor=True)
    per_query_hits = util.semantic_search(encoded_query, corpus_embeddings, top_k=100)
    if not per_query_hits:
        return []
    # A single query was searched, so its hits are the first (only) list.
    return per_query_hits[0]
def retrieve_with_dpr_embeddings(question):
    """Semantic-search the DPR context embeddings with the DPR-encoded question.

    Returns:
        The hit list (at most 100 dicts carrying ``'corpus_id'``) for the single
        query, or ``[]`` if no result list comes back.
    """
    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        question_tokens = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
                                             add_special_tokens=True)
        question_embedding = dpr_trained.model.question_model(**question_tokens)['pooler_output']
        # Drop the batch axis of the single encoded question.
        question_embedding = torch.squeeze(question_embedding, dim=0)
        # Stack per-context embeddings; the dim=1 squeeze assumes each is (1, dim) -- TODO confirm.
        corpus_embeddings = torch.squeeze(torch.stack(dpr_context_embeddings), dim=1)
        hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100)
    return hits[0] if hits else []
def rerank_with_DPR(hits, question):
    """Rerank retrieved *hits* by DPR dot-product score (no cross-encoder involved).

    Each hit's passage and DPR embedding are looked up by ``'corpus_id'`` and scored
    by ``DPR_reranking``.

    Returns:
        ``(contexes, scores)`` -- the reverse of ``DPR_reranking``'s return order.
    """
    corpus_ids = [hit['corpus_id'] for hit in hits]
    candidate_passages = [dpr_contexes[cid] for cid in corpus_ids]
    candidate_vectors = [dpr_context_embeddings[cid] for cid in corpus_ids]
    ranked_scores, ranked_passages = DPR_reranking(question, candidate_passages, candidate_vectors)
    return ranked_passages, ranked_scores
def DPR_reranking(question, selected_contexes, selected_embeddings, top_k=5):
    """Rerank candidate passages by DPR dot-product score against *question*.

    NOTE(review): this is the effective definition; an earlier ``DPR_reranking`` in
    this module is rebound by this one and never called. TODO: remove the duplicate.

    Args:
        question: Query text, encoded with the module-level ``question_tokenizer``
            and ``dpr_trained`` question model.
        selected_contexes: Candidate passages, parallel to *selected_embeddings*.
        selected_embeddings: Candidate DPR embeddings, each accepted by ``dot_product``.
        top_k: Number of passages to keep.

    Returns:
        ``(scores, contexes)`` -- scores first -- for the best *top_k* candidates,
        highest score first. Scores are numpy scalars so the UI can format them.
    """
    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
                                                add_special_tokens=True)
        question_output = dpr_trained.model.question_model(**tokenized_question)['pooler_output']
        scores = [dot_product(question_output, context_embedding).detach().cpu().numpy()[0]
                  for context_embedding in selected_embeddings]
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    scores_final = [scores[idx] for idx in ranking]
    contexes_list = [selected_contexes[idx] for idx in ranking]
    return scores_final, contexes_list
def retrieve_rerank_with_trained_cross_encoder(question, top_k=5):
    """Retrieve with the bi-encoder, then rerank with the fine-tuned cross-encoder.

    Args:
        question: Query text.
        top_k: Number of reranked passages to return.

    Returns:
        ``(contexes, scores)`` for the best *top_k* passages, best first. Both lists
        are empty when retrieval finds nothing; the cross-encoder is not called on an
        empty batch.
    """
    hits = retrieve(question, context_embeddings)
    if not hits:
        return [], []

    cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
    cross_scores = trained_cross_encoder.predict(cross_inp)
    for hit, prediction in zip(hits, cross_scores):
        # NOTE(review): keeps column 0 of each prediction row -- assumes that column is
        # the relevance score of the fine-tuned model; confirm against its training labels.
        hit['cross-score'] = prediction[0]

    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
def interactive_table(dataframe):
    """Render *dataframe* as a paginated AgGrid table with single-row selection.

    Returns:
        The ``AgGrid`` response; callers read its ``'selected_rows'`` entry.
    """
    builder = GridOptionsBuilder.from_dataframe(dataframe)
    builder.configure_pagination(paginationAutoPageSize=True)
    builder.configure_side_bar()
    # Selection mode is 'single' (one row at a time).
    # NOTE(review): groupSelectsChildren receives a label string rather than a bool;
    # it looks like a leftover from an st_aggrid example -- confirm the intent.
    builder.configure_selection('single', rowMultiSelectWithClick=True,
                                groupSelectsChildren="Group checkbox select children")
    grid_options = builder.build()
    return AgGrid(
        dataframe,
        gridOptions=grid_options,
        data_return_mode='AS_INPUT',
        update_mode='SELECTION_CHANGED',
        enable_enterprise_modules=False,
        fit_columns_on_grid_load=False,
        theme='streamlit',
        height=350,
        width='100%',
        reload_data=False,
    )
def img_to_bytes(img_path):
    """Return the contents of the file at *img_path* as a base64-encoded ``str``.

    Used to inline images into HTML via a ``data:`` URI.
    """
    return base64.b64encode(Path(img_path).read_bytes()).decode()
def _render_search_column(option):
    """Render the question form and, after submission, the ranked contexts for *option*."""
    form = st.form(key='first_form')
    question = form.text_area("What is your question?:", height=200)
    submit = form.form_submit_button('Submit')
    if "form_submit" not in st.session_state:
        st.session_state.form_submit = False
    if submit:
        st.session_state.form_submit = True
    # The flag is only cleared when a search finds nothing, so a submitted question
    # keeps being searched (and its results shown) on later reruns.
    if st.session_state.form_submit and question != '':
        with st.spinner(text='Related context search in progress..'):
            results = search_pipeline(question.strip(), selectbox_selections[option])
        # Tolerate a bare [] from a pipeline so unpacking never fails.
        top_contexes, top_scores = results if results else ([], [])
        if len(top_contexes) == 0:
            st.error("Related context not found!")
            st.session_state.form_submit = False
        else:
            for rank, (context, score) in enumerate(zip(top_contexes, top_scores), start=1):
                st.markdown(f"## Related Context - {rank} (score: {score:.2f})")
                st.markdown(context)
                st.markdown("""---""")


def _render_question_table(title, dataframe, state_key):
    """Render a selectable question table and show the selected row's context, question and answer.

    *state_key* is the session-state flag used for this table's click handling.
    """
    st.markdown(title)
    grid_response = interactive_table(dataframe)
    selected_rows = grid_response['selected_rows']
    if state_key not in st.session_state:
        st.session_state[state_key] = False
    if len(selected_rows) > 0:
        st.session_state[state_key] = True
    if st.session_state[state_key]:
        selection = selected_rows[0]
        st.markdown("### Context:")
        st.write(selection['context'])
        st.markdown("### Question:")
        st.write(selection['question'])
        st.markdown("### Answer:")
        st.write(selection['answer'])
        st.session_state[state_key] = False


def qa_main_widgetsv2():
    """Lay out the QA demo page.

    The page has a method picker with its pipeline diagram, the search form with
    results, and two example-question tables.
    """
    st.title("Question Answering Demo")
    st.markdown("""---""")
    option = st.selectbox("Select a search method:", list(selectbox_selections.keys()))
    header_html = "<center> <img src='data:image/png;base64,{}' class='img-fluid' width='60%' height='40%'> </center>".format(
        img_to_bytes(imagebox_selections[option])
    )
    st.markdown(header_html, unsafe_allow_html=True)
    st.markdown("""---""")
    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        _render_search_column(option)
    with col2:
        _render_question_table("## Original Questions", dataframe_original, "grid_click_1")
    with col3:
        _render_question_table("## Our Questions", dataframe_bsbs, "grid_click_2")
@st.cache(show_spinner=False, allow_output_mutation = True)
def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
    """Load every model and tokenizer the demo uses.

    Args:
        dpr_model_path: Hub id/path of the custom DPR model (loaded with remote code).
        auth_token: Token forwarded as ``use_auth_token`` when loading the DPR model.
        cross_encoder_model_path: Path/id of the fine-tuned cross-encoder.

    Returns:
        ``(dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer)``.
    """
    dpr_model = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
                                          trust_remote_code=True)
    sentence_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    ms_marco_reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    # Raise the bi-encoder's input length limit (in tokens) to 500.
    sentence_encoder.max_seq_length = 500
    fine_tuned_reranker = CrossEncoder(cross_encoder_model_path)
    dpr_question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base'
    )
    return (dpr_model, sentence_encoder, ms_marco_reranker,
            fine_tuned_reranker, dpr_question_tokenizer)
# Module-level state read by the search functions above. The loaders are wrapped
# in @st.cache, so Streamlit reruns reuse the cached results.
context_embeddings, contexes = load_paragraphs('context-embeddings.pkl')
dpr_context_embeddings, dpr_contexes = load_paragraphs('custom-dpr-context-embeddings.pkl')
dataframe_original, dataframe_bsbs = load_dataframes()
# The cached models are used directly. The previous copy.deepcopy re-copied every model
# on each rerun, which defeated the cache. Nothing visible in this script mutates the
# models after load_models returns.
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = load_models(
    st.secrets["DPR_MODEL_PATH"],
    st.secrets["AUTH_TOKEN"],
    st.secrets["CROSS_ENCODER_MODEL_PATH"],
)
qa_main_widgetsv2()