policy_test / appStore /keyword_search.py
peter2000's picture
Update appStore/keyword_search.py
55d03cf
raw history blame
No virus
8.72 kB
# set path
import glob, os, sys; sys.path.append('../scripts')
#import helper
import scripts.process as pre
import scripts.clean as clean
#import needed libraries
import seaborn as sns
from pandas import DataFrame
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# from keybert import KeyBERT
from transformers import pipeline
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import pandas as pd
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import numpy as np
import tempfile
import sqlite3
# Paragraphs longer than this many whitespace-separated words are split into
# consecutive chunks of at most this size before being indexed by BM25.
BM25_CHUNK_WORDS = 256


def bm25_tokenizer(text, stop_words=None):
    """Tokenize *text* for BM25 indexing and querying.

    Lower-cases the text, splits it on whitespace, strips leading/trailing
    ``string.punctuation`` characters from each token, and drops tokens that
    end up empty or are stop words.

    Args:
        text: string to tokenize.
        stop_words: container of tokens to drop; defaults to scikit-learn's
            ``ENGLISH_STOP_WORDS``.

    Returns:
        list of token strings, in their original order.
    """
    if stop_words is None:
        stop_words = _stop_words.ENGLISH_STOP_WORDS
    stripped = (token.strip(string.punctuation) for token in text.lower().split())
    return [token for token in stripped if token and token not in stop_words]


def bm25_tokenize_doc(para_list, tokenizer=bm25_tokenizer, chunk_words=BM25_CHUNK_WORDS):
    """Build the BM25 corpus for *para_list*.

    Each paragraph is cut into consecutive chunks of at most *chunk_words*
    words and every chunk is tokenized with *tokenizer*. A blank paragraph
    still yields one (empty) chunk.

    Returns:
        ``(corpus, owners)``: ``corpus[i]`` is the token list of chunk ``i``
        and ``owners[i]`` is the index in *para_list* of the paragraph that
        chunk came from. BM25 results index into *corpus*, so *owners* is
        needed to map them back to paragraphs.
    """
    corpus = []
    owners = []
    for para_idx, passage in enumerate(para_list):
        words = passage.split()
        # max(..., 1) keeps a single empty chunk for a blank paragraph so every
        # paragraph is represented at least once.
        for start in range(0, max(len(words), 1), chunk_words):
            corpus.append(tokenizer(" ".join(words[start:start + chunk_words])))
            owners.append(para_idx)
    return corpus, owners


def top_bm25_hits(scores, owners, n=10):
    """Turn per-chunk BM25 scores into the best paragraph-level hits.

    Takes the *n* highest-scoring chunks, maps each one to its paragraph via
    *owners*, and keeps only the best-scoring chunk of each paragraph.

    Args:
        scores: per-chunk BM25 scores (sequence or 1-D array).
        owners: ``owners[i]`` is the paragraph index of chunk ``i``.
        n: number of top chunks to consider. It is clamped to ``len(scores)``
            so that small documents do not make ``np.argpartition`` raise.

    Returns:
        list of ``{'corpus_id': paragraph_index, 'score': score}`` dicts,
        best first.
    """
    scores = np.asarray(scores)
    k = min(n, scores.size)
    if k <= 0:
        return []
    top = np.argpartition(scores, -k)[-k:]
    hits = []
    seen = set()
    for idx in sorted(top, key=lambda i: scores[i], reverse=True):
        para_idx = owners[idx]
        if para_idx not in seen:
            seen.add(para_idx)
            hits.append({'corpus_id': para_idx, 'score': scores[idx]})
    return hits


def app():
    """Render the Keyword Search page.

    Lets the user upload a document and loads it with ``pre.load_document``.
    The paragraph list (fourth value returned by ``clean.preprocessing``) is
    then indexed two ways: a BM25 lexical index and bi-encoder sentence
    embeddings. When the button is pressed, the top hits of both searches
    are shown for the entered keyword.
    """
    with st.container():
        st.markdown("<h1 style='text-align: center; color: black;'> Keyword Search</h1>", unsafe_allow_html=True)
        st.write(' ')
        st.write(' ')

    with st.expander("ℹ️ - About this app", expanded=True):
        st.write(
            """
            The *Keyword Search* app is an easy-to-use interface built in Streamlit for doing keyword search in policy document - developed by GIZ Data and the Sustainable Development Solution Network.
            """
        )
        st.markdown("")

    st.markdown("")
    st.markdown("## 📌 Step One: Upload document ")

    with st.container():
        file = st.file_uploader('Upload PDF File', type=['pdf', 'docx', 'txt'])
        if file is None:
            return

        with tempfile.NamedTemporaryFile(mode="wb") as temp:
            temp.write(file.getvalue())
            # Flush so the bytes are on disk in case the loader reads the
            # file through temp.name rather than through this handle.
            temp.flush()
            st.write("Filename: ", file.name)
            docs = pre.load_document(temp.name, file)
            _, _, _, paraList = clean.preprocessing(docs)

        keyword = st.text_input("Please enter here what you want to search, we will look for similar context in the document.",
                                value="floods",)

        @st.cache(allow_output_mutation=True)
        def load_sentenceTransformer(name):
            return SentenceTransformer(name)

        # Lexical index. Long paragraphs are chunked, and chunk_owners maps
        # every chunk back to its paragraph index in paraList.
        tokenized_corpus, chunk_owners = bm25_tokenize_doc(paraList)
        if not any(tokenized_corpus):
            # BM25Okapi divides by the vocabulary size, so an index with no
            # tokens at all would crash.
            st.warning("No searchable text was found in this document.")
            return
        document_bm25 = BM25Okapi(tokenized_corpus)

        # Semantic index.
        bi_encoder = load_sentenceTransformer('msmarco-distilbert-cos-v5')  # alternative: multi-qa-MiniLM-L6-cos-v1
        bi_encoder.max_seq_length = 64  # truncate long passages to 64 tokens
        top_k = 32
        document_embeddings = bi_encoder.encode(paraList, convert_to_tensor=True, show_progress_bar=False)
        # Cross-encoder re-ranking (cross-encoder/ms-marco-MiniLM-L-6-v2) was
        # tried here and is currently disabled.

        def search(query):
            """Return ``(bm25_hits, semantic_hits)`` for *query*.

            Both are lists of dicts whose ``'corpus_id'`` indexes paraList.
            """
            bm25_scores = document_bm25.get_scores(bm25_tokenizer(query))
            bm25_hits = top_bm25_hits(bm25_scores, chunk_owners)

            question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
            semantic_hits = util.semantic_search(question_embedding, document_embeddings, top_k=top_k)
            return bm25_hits, semantic_hits[0]  # only one query was encoded

        def show_hit(score, para_idx):
            st.write("\t Score: {:.3f}: \t{}".format(score, paraList[para_idx].replace("\n", " ")))

        if st.button("Find them."):
            bm25_hits, hits = search(keyword)
            st.markdown("We will provide with 2 kind of results. The 'lexical search' and the semantic search.")

            st.markdown("Top few lexical search (BM25) hits")
            for hit in bm25_hits[0:5]:
                if hit['score'] > 0.00:
                    show_hit(hit['score'], hit['corpus_id'])

            st.markdown("\n-------------------------\n")
            st.markdown("Top few Bi-Encoder Retrieval hits")
            hits = sorted(hits, key=lambda x: x['score'], reverse=True)
            for hit in hits[0:5]:
                show_hit(hit['score'], hit['corpus_id'])