paragraphs / app.py
NealCaren's picture
Update app.py
642f7f2
raw
history blame
4.99 kB
import streamlit as st
import numpy as np
import pickle
from collections import OrderedDict
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from nltk.tokenize import sent_tokenize
import nltk
# Sentence tokenizer data used by ``sent_tokenize`` when picking the best
# sentence of each paragraph. Streamlit re-executes this whole script on every
# interaction, so only download it when it is not already installed instead of
# calling the downloader on every rerun.
# NOTE(review): recent NLTK releases make sent_tokenize load 'punkt_tab'
# rather than 'punkt' — confirm against the installed NLTK version.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Console-only warning; the app keeps running without a GPU.
if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")
import pandas as pd
# Page header: title, one-line description and usage notes shown above the search box.
st.title('Sociology Paragraph Search')
st.write('This page is a work-in-progress that allows you to search through articles recently published in a few sociology journals and retrieve the most relevant paragraphs. ')
st.markdown('''Notes:
* To get the best results, search like you are using Google. My best luck comes from phrases, such as "social movements and public opinion", "inequality in latin america", "race color skin tone measurement", "audit study experiment gender", "crenshaw intersectionality" or "logistic regression or linear probability model".
* The dataset currently includes only articles published since 2016 in Social Forces, Social Problems, Sociology of Race and Ethnicity, Gender and Society, Socius, JHSB, and the American Sociological Review (approximately 100K paragraphs from 2K articles).
* The most relevant paragraph to your search is returned first, along with up to four other related paragraphs from that article.
* The most relevant sentence within each paragraph, as determined by math, is bolded.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think.
''')
@st.cache(allow_output_mutation=True)
def sent_trans_load():
    """Load the bi-encoder used to embed the search query.

    Cached with ``st.cache`` so the model is reused across Streamlit reruns
    instead of being reloaded on every search; ``allow_output_mutation=True``
    skips hashing the returned model object on each cache hit.

    Returns:
        SentenceTransformer: ``multi-qa-MiniLM-L6-cos-v1`` with inputs
        truncated to 256 tokens.
    """
    # Paragraph embeddings are loaded precomputed from disk, so at search time
    # this model only has to encode the query.
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # Truncate long inputs to 256 tokens, max 512
    return bi_encoder
@st.cache(allow_output_mutation=True)
def sent_cross_load():
    """Load the cross-encoder used to re-rank retrieved paragraphs.

    The same model also scores individual sentences to choose which one to
    bold. Cached with ``st.cache`` so the model is reused across Streamlit
    reruns instead of being reloaded on every search.

    Returns:
        CrossEncoder: ``cross-encoder/ms-marco-MiniLM-L-6-v2``.
    """
    # The cross-encoder scores (query, text) pairs jointly; the app only
    # applies it to the bi-encoder's candidates, not the whole corpus.
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
@st.cache
def load_data():
    """Read the ``passages_0.jsonl`` .. ``passages_4.jsonl`` shards into one DataFrame.

    Returns:
        pandas.DataFrame whose index is a fresh 0..n-1 range, so rows can be
        looked up with ``df.loc`` by the integer corpus ids that ``search``
        gets back from semantic search.
    """
    shards = [pd.read_json(f'passages_{i}.jsonl', lines=True) for i in range(5)]
    # ignore_index renumbers rows 0..n-1, matching reset_index(drop=True).
    return pd.concat(shards, ignore_index=True)
# Load the paragraph table (cached by load_data) and keep an array of the
# paragraph texts for the cross-encoder lookups in search().
with st.spinner(text="Loading data..."):
    df = load_data()
    passages = df['text'].values
@st.cache
def load_embeddings():
    """Load and stack the precomputed ``embeddings_{i}.pt.npy`` shards.

    Shards 0-4 are concatenated along the first axis, in the same order as
    the passage shards read by ``load_data``.
    NOTE(review): assumes row i of the result is the embedding of row i of
    the passages table — confirm the shards were generated together.
    """
    return np.concatenate([np.load(f'embeddings_{i}.pt.npy') for i in range(5)])
# Precomputed paragraph embeddings searched by the bi-encoder in search().
with st.spinner("Loading embeddings..."):
    corpus_embeddings = load_embeddings()
def _bold_best_sentence(query, graph):
    """Return *graph* with its sentence most relevant to *query* wrapped in ``**``."""
    sentences = sent_tokenize(graph)
    if not sentences:
        # Empty/whitespace-only paragraph: nothing to score or highlight
        # (the cross-encoder cannot be called with an empty batch).
        return graph
    scores = cross_encoder.predict([[query, s] for s in sentences])
    best = sentences[int(np.argmax(scores))]
    return graph.replace(best, f'**{best}**')


def _format_cite(cite):
    """Rewrite a stored citation string into the form written to the page.

    Every ", " becomes '. "'.
    NOTE(review): the second replace can never match — the first one has
    already rewritten every ", " (including the one in ", Social ") — so it
    is a no-op. Kept as-is because the intended citation format is not
    visible here; fix once the source format is confirmed.
    """
    return cite.replace(", ", '. "').replace(', Social ', '", Social ')


def search(query, top_k=40, n_results=20):
    """Retrieve, re-rank and render the paragraphs most relevant to *query*.

    Two-stage "retrieve & re-rank" search:

    1. The bi-encoder embeds *query* and ``util.semantic_search`` returns the
       ``top_k`` nearest rows of ``corpus_embeddings``.
    2. The cross-encoder scores every (query, paragraph) pair and the hits
       are re-sorted by that score.

    The best ``n_results`` paragraphs are grouped by article citation
    (articles ordered by their best-ranked paragraph), the highest-scoring
    sentence of each paragraph is bolded, and up to five paragraphs per
    article are written to the page.

    Reads module-level globals ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages`` and ``df``.

    Args:
        query: The user's search phrase.
        top_k: Number of candidates retrieved by the bi-encoder.
        n_results: Number of re-ranked paragraphs kept for display
            (20, the previously hard-coded value, by default).

    Returns:
        None; results are rendered with ``st.write``.
    """
    # Stage 1: semantic search. semantic_search returns one hit list per
    # query; there is only one query here.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        return

    # Stage 2: re-rank the candidates with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    hits = sorted(hits, key=lambda h: h['cross-score'], reverse=True)

    # Group the best paragraphs by article; insertion order keeps articles in
    # the order of their best-ranked paragraph.
    by_cite = OrderedDict()
    for hit in hits[:n_results]:
        row = df.loc[hit['corpus_id']]
        graph = _bold_best_sentence(query, row['text'])
        by_cite.setdefault(row['cite'], []).append(graph)

    for cite, graphs in by_cite.items():
        st.write(_format_cite(cite))
        for graph in graphs[:5]:
            st.write(f'* {graph}')
        st.write('')
# Entry point: Streamlit re-runs this script on every interaction; a search
# runs only once the user has entered a non-empty phrase.
search_query = st.text_input('Enter your search phrase:')
if search_query:
    with st.spinner(text="Searching and sorting results (may take up to 30 seconds)"):
        # search() reads both models as module-level globals.
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)