secilozksen's picture
Upload 8 files
38a4d89
raw history blame
No virus
13.3 kB
import copy
import streamlit as st
import json
import pandas as pd
import tokenizers
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from transformers import pipeline
from st_aggrid import GridOptionsBuilder, AgGrid
import pickle
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import spacy
import regex
from typing import List
from torch.autograd import Variable
st.set_page_config(layout="wide")
DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
DATAFRAME_FILE_BSBS = 'policyQA_bsbs_sentence.csv'
@st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
def cross_encoder_init():
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
return cross_encoder
@st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
def bi_encoder_init():
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 500 # Truncate long passages to 256 tokens
return bi_encoder
@st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
def nlp_init(auth_token, private_model_name):
return pipeline('question-answering', model=private_model_name, tokenizer=private_model_name,
use_auth_token=auth_token,
revision="main")
@st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
def nlp_pipeline_hf():
model_name = "deepset/roberta-base-squad2"
return pipeline('question-answering', model=model_name, tokenizer=model_name)
@st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
def nlp_pipeline_sentence_based(auth_token, private_model_name):
tokenizer = RobertaTokenizer.from_pretrained(private_model_name, use_auth_token=auth_token)
model = RobertaForSequenceClassification.from_pretrained(private_model_name, use_auth_token=auth_token)
return tokenizer, model
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None,
regex.Pattern: lambda _: None}, show_spinner=False)
def load_models_sentence_based(auth_token, private_model_name, private_model_name_base):
bi_encoder = bi_encoder_init()
cross_encoder = cross_encoder_init()
# OLD MODEL
# nlp = nlp_init(auth_token, private_model_name)
# nlp_hf = nlp_pipeline_hf()
policy_qa_tokenizer, policy_qa_model = nlp_pipeline_sentence_based(auth_token, private_model_name)
asnq_tokenizer, asnq_model = nlp_pipeline_sentence_based(auth_token, private_model_name_base)
return bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, show_spinner=False)
def load_models(auth_token, private_model_name):
bi_encoder = bi_encoder_init()
cross_encoder = cross_encoder_init()
nlp = nlp_init(auth_token, private_model_name)
nlp_hf = nlp_pipeline_hf()
return bi_encoder, cross_encoder, nlp, nlp_hf
def context():
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', device='cpu')
with open("/home/secilsen/PycharmProjects/SquadOperations/contexes.json", 'r', encoding='utf-8') as f:
paragraphs = json.load(f)
paragraphs = paragraphs['contexes']
with open('context-embeddings.pkl', "wb") as fIn:
context_embeddings = bi_encoder.encode(paragraphs, convert_to_tensor=True, show_progress_bar=True)
pickle.dump({'contexes': paragraphs, 'embeddings': context_embeddings}, fIn)
@st.cache(show_spinner=False)
def load_paragraphs():
with open('context-embeddings.pkl', "rb") as fIn:
cache_data = pickle.load(fIn)
corpus_sentences = cache_data['contexes']
corpus_embeddings = cache_data['embeddings']
return corpus_embeddings, corpus_sentences
@st.cache(show_spinner=False)
def load_dataframes():
data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
data_original = data_original.sample(frac=1).reset_index(drop=True)
data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
return data_original, data_bsbs
def search(question, corpus_embeddings, contexes, bi_encoder, cross_encoder):
# Semantic Search (Retrieve)
question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100)
if len(hits) == 0:
return []
hits = hits[0]
# Rerank - score all retrieved passages with cross-encoder
cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
# Sort results by the cross-encoder scores
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
# Output of top-5 hits from re-ranker
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
top_5_contexes = []
top_5_scores = []
for hit in hits[0:20]:
top_5_contexes.append(contexes[hit['corpus_id']])
top_5_scores.append(hit['cross-score'])
return top_5_contexes, top_5_scores
def paragraph_embeddings():
paragraphs = load_paragraphs()
context_embeddings = bi_encoder.encode(paragraphs, convert_to_tensor=True, show_progress_bar=True)
return context_embeddings, paragraphs
def retrieve_rerank_pipeline(question, context_embeddings, paragraphs, bi_encoder, cross_encoder):
top_5_contexes, top_5_scores = search(question, context_embeddings, paragraphs, bi_encoder, cross_encoder)
return top_5_contexes, top_5_scores
def qa_pipeline(question, context, nlp):
return nlp({'question': question.strip(), 'context': context})
def qa_pipeline_sentence(question, context, model, tokenizer):
sentences_doc = spacy_nlp(context)
candidate_sentences = []
for sentence in sentences_doc.sents:
tokenized = tokenizer(f"<s> {question} </s> {sentence.text} </s>", padding=True, truncation=True, return_tensors='pt')
output = model(**tokenized)
soft_outputs = torch.nn.functional.sigmoid(output[0])
t = Variable(torch.Tensor([0.2])) # threshold
out = (soft_outputs[0] > t) * 1
out = out.flatten().cpu().detach().numpy()
# res = torch.argmax(out, dim=-1)
print(out[1])
if out[1] == 1:
prob = soft_outputs[:, 1].flatten().cpu().detach().numpy()
candidate_sentences.append(dict(sentence=sentence,
prob=prob[0]))
print(candidate_sentences)
candidate_sentences = sorted(candidate_sentences, key=lambda x: x['prob'], reverse=True)
return candidate_sentences
def candidate_sentence_controller(sentences):
if sentences is None or len(sentences) == 0:
return ""
if len(sentences) == 1:
return sentences[0]
return sentences
def interactive_table(dataframe):
gb = GridOptionsBuilder.from_dataframe(dataframe)
gb.configure_pagination(paginationAutoPageSize=True)
gb.configure_side_bar()
gb.configure_selection('single', rowMultiSelectWithClick=True,
groupSelectsChildren="Group checkbox select children") # Enable multi-row selection
gridOptions = gb.build()
grid_response = AgGrid(
dataframe,
gridOptions=gridOptions,
data_return_mode='AS_INPUT',
update_mode='SELECTION_CHANGED',
enable_enterprise_modules=False,
fit_columns_on_grid_load=False,
theme='streamlit', # Add theme color to the table
height=350,
width='100%',
reload_data=False
)
return grid_response
def qa_main_widgetsv2():
st.title("Question Answering Demo")
col1, col2, col3 = st.columns([2, 1, 1])
with col1:
form = st.form(key='first_form')
question = form.text_area("What is your question?:", height=200)
submit = form.form_submit_button('Submit')
if "form_submit" not in st.session_state:
st.session_state.form_submit = False
if submit:
st.session_state.form_submit = True
if st.session_state.form_submit and question != '':
with st.spinner(text='Related context search in progress..'):
top_5_contexes, top_5_scores = retrieve_rerank_pipeline(question.strip(), context_embeddings,
paragraphs, bi_encoder,
cross_encoder)
if len(top_5_contexes) == 0:
st.error("Related context not found!")
st.session_state.form_submit = False
else:
with st.spinner(text='Now answering your question..'):
for i, context in enumerate(top_5_contexes):
# answer_trained = qa_pipeline(question, context, nlp)
# answer_base = qa_pipeline(question, context, nlp_hf)
answer_trained = qa_pipeline_sentence(question, context, policy_qa_model, policy_qa_tokenizer)
answer_base = qa_pipeline_sentence(question, context, asnq_model, asnq_tokenizer)
st.markdown(f"## Related Context - {i + 1} (score: {top_5_scores[i]:.2f})")
st.markdown(context)
st.markdown("## Answer (trained):")
if answer_trained is None:
st.markdown("")
elif isinstance(answer_trained, List):
for i,answer in enumerate(answer_trained):
st.markdown(f"### Answer Option {i+1} with prob. {answer['prob']:.4f}")
st.markdown(answer['sentence'])
else:
st.markdown(answer_trained)
# st.markdown(answer_trained['answer'])
st.markdown("## Answer (roberta-base-asnq):")
if answer_base is None:
st.markdown("")
elif isinstance(answer_base, List):
for i,answer in enumerate(answer_base):
st.markdown(f"### Answer Option {i + 1} with prob. {answer['prob']:.4f}")
st.markdown(answer['sentence'])
else:
st.markdown(answer_base)
st.markdown("""---""")
with col2:
st.markdown("## Original Questions")
grid_response = interactive_table(dataframe_original)
data1 = grid_response['selected_rows']
if "grid_click_1" not in st.session_state:
st.session_state.grid_click_1 = False
if len(data1) > 0:
st.session_state.grid_click_1 = True
if st.session_state.grid_click_1:
selection = data1[0]
# st.markdown("## Context & Answer:")
st.markdown("### Context:")
st.write(selection['context'])
st.markdown("### Question:")
st.write(selection['question'])
st.markdown("### Answer:")
st.write(selection['answer'])
st.session_state.grid_click_1 = False
with col3:
st.markdown("## Our Questions")
grid_response = interactive_table(dataframe_bsbs)
data2 = grid_response['selected_rows']
if "grid_click_2" not in st.session_state:
st.session_state.grid_click_2 = False
if len(data2) > 0:
st.session_state.grid_click_2 = True
if st.session_state.grid_click_2:
selection = data2[0]
# st.markdown("## Context & Answer:")
st.markdown("### Context:")
st.write(selection['context'])
st.markdown("### Question:")
st.write(selection['question'])
st.markdown("### Answer:")
st.write(selection['answer'])
st.session_state.grid_click_2 = False
def load():
context_embeddings, paragraphs = load_paragraphs()
dataframe_original, dataframe_bsbs = load_dataframes()
spacy_nlp = spacy.load('en_core_web_sm')
# bi_encoder, cross_encoder, nlp, nlp_hf = copy.deepcopy(load(st.secrets["AUTH_TOKEN"], st.secrets["MODEL_NAME"]))
bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model \
= copy.deepcopy(
load_models_sentence_based(st.secrets["AUTH_TOKEN"], st.secrets["MODEL_NAME"], st.secrets["MODEL_NAME_BASE"]))
return context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model, spacy_nlp
# save_dataframe()
# context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, nlp, nlp_hf = load()
context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model, spacy_nlp = load()
qa_main_widgetsv2()
# if __name__ == '__main__':
# context()