|
import requests |
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util |
|
import os, re |
|
import torch |
|
from rank_bm25 import BM25Okapi |
|
from sklearn.feature_extraction import _stop_words |
|
import string |
|
import numpy as np |
|
import pandas as pd |
|
from newspaper import Article |
|
import base64 |
|
import docx2txt |
|
from io import StringIO |
|
from PyPDF2 import PdfFileReader |
|
import validators |
|
import nltk |
|
import warnings |
|
import streamlit as st |
|
from PIL import Image |
|
|
|
|
|
# Request the NLTK "punkt" resource at start-up (presumably the data that
# sent_tokenize relies on -- confirm against the installed NLTK version).
nltk.download("punkt")

from nltk.tokenize import sent_tokenize

# Silence every warning category for the lifetime of the process.
warnings.filterwarnings(action="ignore")
|
|
|
def extract_text_from_url(url: str):
    """Download the page at ``url`` and return its title and body text.

    The page is fetched and parsed with ``Article``.

    Args:
        url: Address of the page to fetch.

    Returns:
        A ``(title, text)`` tuple read from the parsed article.
    """
    page = Article(url)
    page.download()
    page.parse()
    return page.title, page.text
|
|
|
def extract_text_from_file(file):
    """Extract the text (and, for PDFs, the title) from an uploaded file.

    The handler is selected from ``file.type``, which is compared against
    the MIME types for plain text, PDF and .docx documents.

    Args:
        file: Uploaded file object exposing a ``type`` attribute. Plain-text
            files must also provide ``getvalue()`` returning UTF-8 bytes.

    Returns:
        A ``(text, title)`` tuple. ``title`` is ``None`` for plain-text and
        .docx files, and for PDFs that carry no document information.

    Raises:
        ValueError: If ``file.type`` is not one of the supported MIME types.
    """
    mime_type = file.type

    if mime_type == "text/plain":
        return file.getvalue().decode("utf-8"), None

    if mime_type == "application/pdf":
        pdf_reader = PdfFileReader(file)

        # The document-info lookup may yield None; guard before reading
        # the title so such PDFs do not raise AttributeError.
        doc_info = pdf_reader.getDocumentInfo()
        pdf_title = doc_info.title if doc_info is not None else None

        page_texts = []
        for page_number in range(pdf_reader.numPages):
            try:
                page_texts.append(pdf_reader.getPage(page_number).extractText())
            except Exception:
                # Best effort: skip pages whose text cannot be extracted,
                # without swallowing KeyboardInterrupt/SystemExit.
                continue

        return "".join(page_texts), pdf_title

    if (
        mime_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ):
        return docx2txt.process(file), None

    # Previously this fell through and returned None, which surfaced at the
    # call site as an opaque tuple-unpacking TypeError.
    raise ValueError(f"Unsupported file type: {mime_type!r}")
|
|
|
def preprocess_plain_text(text, window_size=3):
    """Clean raw document text and split it into passages of whole sentences.

    Non-ASCII characters are dropped; URLs, ``@mentions`` and ``#hashtags``
    are blanked out; whitespace is normalised; the result is split into
    sentences with ``sent_tokenize`` and grouped into consecutive,
    non-overlapping windows. The sentence and passage counts are reported
    with ``st.write``.

    Args:
        text: Raw document text.
        window_size: Number of consecutive sentences per passage (>= 1).

    Returns:
        A list of passage strings, each made of up to ``window_size``
        sentences joined by a single space.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    text = text.encode("ascii", "ignore").decode()
    # Order matters: URLs are removed before mentions and hashtags.
    for pattern in (r"https*\S+", r"@\S+", r"#\S+"):
        text = re.sub(pattern, " ", text)
    text = re.sub(r"\s{2,}", " ", text)

    # Re-join every non-empty word with single spaces.
    words = [
        phrase.strip()
        for line in text.splitlines()
        for phrase in line.strip().split(" ")
    ]
    text = " ".join(word for word in words if word)

    # NOTE(review): paragraph boundaries are not preserved. Runs of
    # whitespace (including blank lines) are collapsed above, so the whole
    # document is tokenised as one paragraph -- the same result the earlier
    # split on "\n\n" produced, since it ran after newlines were removed.
    paragraphs = [sent_tokenize(text)] if text else []

    passages = [
        " ".join(sentences[start:start + window_size])
        for sentences in paragraphs
        for start in range(0, len(sentences), window_size)
    ]

    st.write(f"Sentences: {sum(len(p) for p in paragraphs)}")
    st.write(f"Passages: {len(passages)}")

    return passages
|
|
|
@st.experimental_memo(suppress_st_warning=True)
def bi_encode(bi_enc, passages):
    """Load the named bi-encoder and embed every passage with it.

    Args:
        bi_enc: Model name handed to ``SentenceTransformer``.
        passages: Passage strings to encode.

    Returns:
        A ``(bi_encoder, corpus_embeddings)`` tuple: the loaded model and
        the tensor of passage embeddings.
    """
    # The model is also published through the module-level ``bi_encoder``.
    global bi_encoder

    bi_encoder = SentenceTransformer(bi_enc)

    with st.spinner('Encoding passages into a vector space...'):
        embeddings = bi_encoder.encode(
            passages,
            convert_to_tensor=True,
            show_progress_bar=True,
        )

    st.success(f"Embeddings computed. Shape: {embeddings.shape}")

    return bi_encoder, embeddings
|
|
|
# ``allow_output_mutation`` is an ``st.cache`` option; the singleton
# decorator does not accept it, so it is applied without that argument.
@st.experimental_singleton
def cross_encode():
    """Load the MS MARCO MiniLM cross-encoder used for re-ranking.

    Returns:
        The ``CrossEncoder`` instance, which is also stored in the
        module-level ``cross_encoder``.
    """
    global cross_encoder

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
    return cross_encoder
|
|
|
# ``allow_output_mutation`` is an ``st.cache`` option; the memo decorator
# does not accept it, so it is applied without that argument.
@st.experimental_memo
def bm25_tokenizer(text):
    """Tokenise ``text`` for BM25 scoring.

    The text is lower-cased and split on whitespace; each token is
    stripped of leading/trailing punctuation; empty tokens and English
    stop words are dropped.

    Args:
        text: String to tokenise.

    Returns:
        The list of remaining tokens, in their original order.
    """
    stripped_tokens = (
        token.strip(string.punctuation) for token in text.lower().split()
    )
    return [
        token
        for token in stripped_tokens
        if token and token not in _stop_words.ENGLISH_STOP_WORDS
    ]
|
|
|
# ``allow_output_mutation`` is an ``st.cache`` option; the singleton
# decorator does not accept it, so it is applied without that argument.
@st.experimental_singleton
def bm25_api(passages):
    """Build a BM25 index over ``passages``.

    Args:
        passages: Passage strings; each is tokenised with
            ``bm25_tokenizer``.

    Returns:
        A ``BM25Okapi`` instance built from the tokenised passages.
    """
    tokenized_corpus = [bm25_tokenizer(passage) for passage in passages]
    return BM25Okapi(tokenized_corpus)
|
|
|
# Bi-encoder model names offered in the sidebar select box.
bi_enc_options = [
    "multi-qa-mpnet-base-dot-v1",
    "all-mpnet-base-v2",
    "multi-qa-MiniLM-L6-cos-v1",
]
|
|
|
def display_df_as_table(model, top_k, score='score', corpus=None):
    """Build a two-column Score/Text table from the first ``top_k`` hits.

    Args:
        model: Sequence of hit dicts, each with a ``'corpus_id'`` entry and
            an entry under the ``score`` key.
        top_k: Number of leading hits to include.
        score: Key of the score to display (e.g. ``'score'`` or
            ``'cross-score'``).
        corpus: Sequence the ``'corpus_id'`` values index into. Defaults to
            the module-level ``passages`` when omitted, as before.

    Returns:
        A ``pandas.DataFrame`` with columns ``Score`` (rounded to two
        decimals) and ``Text``.
    """
    texts = passages if corpus is None else corpus

    rows = [(hit[score], texts[hit['corpus_id']]) for hit in model[:top_k]]
    df = pd.DataFrame(rows, columns=['Score', 'Text'])
    df['Score'] = df['Score'].round(2)

    return df
|
|
|
|
|
|
|
st.title("Semantic Search with Retrieve & Rerank 📝")

# Bare string expression, kept as a statement exactly as in the original
# (presumably rendered by Streamlit's "magic" feature -- confirm).
"""
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

# Sidebar controls: sentences per passage, bi-encoder model, hits to show.
window_size = st.sidebar.slider(
    "Paragraph Window Size",
    min_value=1,
    max_value=10,
    value=3,
    key='slider',
)

bi_encoder_type = st.sidebar.selectbox(
    "Bi-Encoder",
    options=bi_enc_options,
    key='sbox',
)

top_k = st.sidebar.slider(
    "Number of Top Hits Generated",
    min_value=1,
    max_value=5,
    value=2,
)
|
|
|
|
|
|
|
def _show_hits(heading, hits, top_k, score_key='score'):
    """Render ``hits`` as an HTML table under ``heading``, best score first."""
    st.subheader(heading)
    ranked = sorted(hits, key=lambda hit: hit[score_key], reverse=True)
    table = display_df_as_table(ranked, top_k, score_key)
    st.write(table.to_html(index=False), unsafe_allow_html=True)


def search_func(query, top_k=top_k):
    """Run lexical, bi-encoder and cross-encoder search and render the results.

    Reads the module-level ``bm25``, ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings`` and ``passages`` prepared by the caller, and
    writes three result tables to the page: BM25 hits, bi-encoder hits and
    the same bi-encoder hits re-scored by the cross-encoder.

    Args:
        query: The user's search query.
        top_k: Number of hits to retrieve and display per table.
    """
    st.subheader(f"Search Query: {query}")

    # Only one of ``title`` / ``pdf_title`` is assigned by the loading code
    # (and neither when loading was skipped), so look them up defensively
    # rather than risk a NameError.
    module_vars = globals()
    document_header = module_vars.get("title") or module_vars.get("pdf_title")
    if document_header:
        st.write(f"Document Header: {document_header}")

    # --- Lexical search (BM25) ---
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))

    # Clamp the candidate count to the corpus size: argpartition raises
    # ValueError when asked for more elements than exist.
    n_candidates = min(len(bm25_scores), max(top_k, 5))
    if n_candidates:
        top_n = np.argpartition(bm25_scores, -n_candidates)[-n_candidates:]
    else:
        top_n = []
    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]

    _show_hits(f"Top-{top_k} lexical search (BM25) hits", bm25_hits, top_k)

    # --- Semantic search (bi-encoder retrieval) ---
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    question_embedding = question_embedding.cpu()
    hits = util.semantic_search(
        question_embedding,
        corpus_embeddings,
        top_k=top_k,
        score_function=util.dot_score,
    )
    hits = hits[0]  # results for the single query

    # --- Re-ranking (cross-encoder) ---
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    st.markdown("\n-------------------------\n")
    _show_hits(f"Top-{top_k} Bi-Encoder Retrieval hits", hits, top_k)

    st.markdown("\n-------------------------\n")
    _show_hits(
        f"Top-{top_k} Cross-Encoder Re-ranker hits",
        hits,
        top_k,
        'cross-score',
    )
|
|
|
# Explanatory sections shown above the architecture image, in display order.
_ABOUT_SECTIONS = (
    """
- The app supports asymmetric Semantic search which seeks to improve search accuracy of documents/URL by understanding the content of the search query in contrast to traditional search engines which only find documents based on lexical matches.
- The idea behind semantic search is to embed all entries in your corpus, whether they be sentences, paragraphs, or documents, into a vector space. At search time, the query is embedded into the same vector space and the closest embeddings from your corpus are found. These entries should have a high semantic overlap with the query.
- The all-* models where trained on all available training data (more than 1 billion training pairs) and are designed as general purpose models. The all-mpnet-base-v2 model provides the best quality, while all-MiniLM-L6-v2 is 5 times faster and still offers good quality. The models used have been trained on broad datasets, however, if your document/corpus is specialised, such as for science or economics, the results returned might be unsatisfactory.""",
    """There models available to choose from:""",
    """
Model Source:
- Bi-Encoders - [multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1), [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2), [multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- Cross-Encoder - [cross-encoder/ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2)""",
    """
Code and App Inspiration Source: [Sentence Transformers](https://www.sbert.net/examples/applications/retrieve_rerank/README.html)""",
    """
Quick summary of the purposes of a Bi and Cross-encoder below, the image and info were adapted from [www.sbert.net](https://www.sbert.net/examples/applications/semantic-search/README.html):""",
    """
- Bi-Encoder (Retrieval): The Bi-encoder is responsible for independently embedding the sentences and search queries into a vector space. The result is then passed to the cross-encoder for checking the relevance/similarity between the query and sentences.
- Cross-Encoder (Re-Ranker): A re-ranker based on a Cross-Encoder can substantially improve the final results for the user. The query and a possible document is passed simultaneously to transformer network, which then outputs a single score between 0 and 1 indicating how relevant the document is for the given query. The cross-encoder further boost the performance, especially when you search over a corpus for which the bi-encoder was not trained for.""",
)

for _about_section in _ABOUT_SECTIONS:
    st.markdown(_about_section)

# Architecture diagram, loaded from a file next to the script.
st.image(Image.open('encoder.png'), caption='Retrieval and Re-Rank')

# Step-by-step usage instructions.
_USAGE_TEXT = """
In order to use the app:
- Select the preferred Sentence Transformer model (Bi-Encoder).
- Select the number of sentences per paragraph to partition your corpus (Window-Size), if you choose a small value the context from the other sentences might get lost and for larger values the results might take longer to generate.
- Select the number of top hits to be generated.
- Paste the URL with your corpus or upload your preferred document in txt, pdf or Word format.
- Semantic Search away!! """

st.markdown(_USAGE_TEXT)

st.markdown("---")
|
|
|
def clear_text():
    """Blank both the URL field and the search-query field in session state."""
    for widget_key in ("text_url", "text_input"):
        st.session_state[widget_key] = ""
|
|
|
def clear_search_text():
    """Blank the search-query field in session state (URL on-change hook)."""
    st.session_state.text_input = ""
|
|
|
# --- Inputs: document source (URL or upload) and the search query ---
url_text = st.text_input(
    "Please Enter a url here",
    value="https://www.rba.gov.au/monetary-policy/rba-board-minutes/2022/2022-05-03.html",
    key='text_url',
    on_change=clear_search_text,
)

st.markdown(
    "<h3 style='text-align: center; color: red;'>OR</h3>",
    unsafe_allow_html=True,
)

upload_doc = st.file_uploader("Upload a .txt, .pdf, .docx file", key="upload")

search_query = st.text_input(
    "Please Enter your search query here",
    value="What are the expectations for inflation for Australia?",
    key="text_input",
)

# Defaults so that later code never hits a NameError when no document was
# loaded, or when only one of the two sources was used.
title = None
pdf_title = None
passages = []

# --- Load the document; a valid URL takes precedence over an upload ---
if validators.url(url_text):
    title, text = extract_text_from_url(url_text)
    passages = preprocess_plain_text(text, window_size=window_size)

elif upload_doc:
    extracted = extract_text_from_file(upload_doc)
    if extracted is None:
        # No (text, title) pair came back, so there is nothing to unpack.
        st.error("Unsupported file type. Please upload a .txt, .pdf or .docx file.")
    else:
        text, pdf_title = extracted
        passages = preprocess_plain_text(text, window_size=window_size)

# --- Action buttons ---
col1, col2 = st.columns(2)

with col1:
    search = st.button("Search", key='search_but', help='Click to Search!!')

with col2:
    clear = st.button(
        "Clear Text Input",
        on_click=clear_text,
        key='clear',
        help='Click to clear the URL input and search query',
    )

# --- Run the search ---
if search:
    if not passages:
        # Nothing to index: tell the user instead of failing downstream.
        st.warning(
            "No document text available. Please enter a valid URL or upload "
            "a .txt, .pdf or .docx file before searching."
        )

    elif bi_encoder_type:
        with st.spinner(
            text=f"Loading {bi_encoder_type} bi-encoder and embedding document into vector space. This might take a few seconds depending on the length of your document..."
        ):
            bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type, passages)
            cross_encoder = cross_encode()
            bm25 = bm25_api(passages)

        with st.spinner(
            text="Embedding completed, searching for relevant text for given query and hits..."
        ):
            search_func(search_query, top_k)
|
|
|
# Footer: a blank markdown block followed by the visitor-counter badge.
_FOOTER_ITEMS = (
    """
""",
    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-semantic-search-with-retrieve-and-rerank)",
)

for _footer_item in _FOOTER_ITEMS:
    st.markdown(_footer_item)