# sm_hunt / app.py
# NealCaren's picture
# Upload app.py
# 90399d1
import streamlit as st
import numpy as np
import re
import pickle
from collections import OrderedDict
import io
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from nltk.tokenize import sent_tokenize
import nltk
import gdown
import requests
from PIL import Image
# Page-wide CSS override: render the expander headers at medium font size.
_EXPANDER_CSS = """
<style>
.streamlit-expanderHeader {
font-size: medium;
}
</style>
"""
st.markdown(_EXPANDER_CSS, unsafe_allow_html=True)

# sent_tokenize (used when picking the best sentence) needs the punkt data.
nltk.download('punkt')

# Run torch work on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# pandas is imported here rather than with the other imports at the top.
import pandas as pd
@st.cache_data
def load_embeddings():
    """Download the precomputed corpus embeddings and return them as an array.

    The file is fetched from Google Drive with gdown into ``embeddings.npy``
    in the working directory and then read with ``np.load``. The result is
    cached by Streamlit, so the download happens only on a cache miss.
    """
    local_path = "embeddings.npy"
    gdown.download(
        "https://drive.google.com/uc?export=download&id=14y-RQ18IQ3tP7p9iMTeDKSsvFAKz1bLv",
        local_path,
        quiet=False,
    )
    return np.load(local_path)
@st.cache_data
def load_data():
    """Download the passages file and return it as a DataFrame.

    The JSON-lines file is fetched from Google Drive with gdown into
    ``passages.jsonl`` and parsed one record per line. The returned frame has
    a fresh 0..n-1 index. Other code in this file reads its ``Abstract`` and
    ``cite`` columns.
    """
    local_path = "passages.jsonl"
    gdown.download(
        "https://drive.google.com/uc?export=download&id=1--6zc38C-FfIb-C4BMG87Bvx947Z1UNO",
        local_path,
        quiet=False,
    )
    frame = pd.read_json(local_path, lines=True)
    return frame.reset_index(drop=True)
# Page header and one-line description of the app.
st.title('Related Social Movement Articles')
st.write('This project is a work-in-progress that searches the abstracts of recently-published articles related to social movements and retrieves the most relevant articles.')
# Load the passages table (cached by load_data) behind a spinner.
with st.spinner(text="Loading data..."):
    df = load_data()
# Module-level globals read by search(): the abstract texts, indexed by row.
passages = df['Abstract'].values
# Counts interpolated into the notes text below: rows, and distinct 'cite' values.
no_of_graphs=len(df)
no_of_articles = len(df['cite'].value_counts())
# Usage notes in Markdown. Currently built but not rendered (see the
# commented-out st.markdown call after the string).
notes = f'''Notes:
* I have found three types of searches work best:
* Phrases or specific topics, such as "inequality in latin america", "race color skin tone measurement", "audit study experiment gender", or "logistic regression or linear probability model".
* Citations to well-known works, either using author year ("bourdieu 1984") or author idea ("Crenshaw intersectionality")
* Questions, like "What is a topic model?" or "How did Weber define bureaucracy?"
* The search expands beyond exact matching, so "asia social movements" may return paragraphs on Asian-Americans politics and South Korean labor unions.
* The first search can take up to 10 seconds as the files load. After that, it's quicker to respond.
* The most relevant paragraph to your search is returned first, along with up to four other related paragraphs from that article.
* The most relevant sentence within each paragraph, as determined by math, is displayed. Click on it to see the full paragraph.
* The results are not exhaustive, and seem to drift off even when you suspect there are more relevant articles :man-shrugging:.
* The dataset currently includes {no_of_graphs:,} paragraphs from {no_of_articles:,} published in the last five years in *Mobilization*, *Social Forces*, *Social Problems*, *Sociology of Race and Ethnicity*, *Gender and Society*, *Socius*, *JHSB*, *Annual Review of Sociology*, and the *American Sociological Review*.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think or it looks broken.
'''
# Rendering of the notes is disabled for now.
# st.markdown(notes)
@st.cache_resource
def sent_trans_load():
    """Return the bi-encoder used to embed search queries.

    The model is cached with ``st.cache_resource`` so it is built once and
    reused across Streamlit reruns, instead of being reloaded on every
    search.
    """
    # We use the Bi-Encoder to encode all passages, so that we can use it
    # with semantic search.
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens, max 512
    return bi_encoder
@st.cache_resource
def sent_cross_load():
    """Return the cross-encoder used to re-rank retrieved passages.

    The model is cached with ``st.cache_resource`` so it is built once and
    reused across Streamlit reruns, instead of being reloaded on every
    search.
    """
    # The Cross-Encoder scores (query, text) pairs; search() uses it to
    # re-rank the passages retrieved by the bi-encoder and to pick the best
    # sentence inside each passage.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    return cross_encoder
# Load the precomputed corpus embeddings (cached by load_embeddings) behind a
# spinner. search() reads corpus_embeddings as a module-level global.
with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()
def search(query, top_k=50):
    """Find, re-rank and render the passages most relevant to ``query``.

    Steps:
      1. Embed the query with the bi-encoder and fetch the ``top_k`` nearest
         corpus passages.
      2. Re-score those passages with the cross-encoder and sort by that
         score.
      3. For the 30 best passages, find the single best-scoring sentence and
         wrap it in ``**`` so it is shown in bold inside the full passage.
      4. Group passages by citation (in rank order) and render up to five
         per citation as expanders labelled with the best sentence.

    Reads the module-level ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages``, ``df`` and ``device``.

    Args:
        query: The search text entered by the user.
        top_k: Number of candidate passages retrieved before re-ranking.

    Returns:
        None. Results are written directly to the Streamlit page.
    """
    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True).to(device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first (and only) query
    if not hits:
        # Nothing retrieved: there is nothing to re-rank or render.
        return

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    print("\n-------------------------\n")
    print("Search Results")
    # Sort results by the cross-encoder scores, best first
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    # cite -> list of (best sentence, passage with that sentence in bold).
    # The best sentence is stored next to the passage so it never has to be
    # parsed back out of the markdown, which broke when the sentence spanned
    # a newline or the abstract already contained "**".
    hd = OrderedDict()
    for hit in hits[0:30]:
        row_id = hit['corpus_id']
        cite = df.loc[row_id, 'cite']
        graph = df.loc[row_id, 'Abstract']

        # Find best sentence
        ab_sentences = sent_tokenize(graph)
        if not ab_sentences:
            # No sentence to highlight or to use as an expander label.
            continue
        sentence_scores = cross_encoder.predict([[query, s] for s in ab_sentences])
        thesis = ab_sentences[int(np.argmax(sentence_scores))]
        graph = graph.replace(thesis, f'**{thesis}**')

        hd.setdefault(cite, []).append((thesis, graph))

    for cite, graphs in hd.items():
        st.markdown(cite)
        # Show at most five passages per article.
        for thesis, graph in graphs[:5]:
            with st.expander(thesis):
                st.write(f'> {graph}')
        st.write('')
search_query = st.text_area('Enter abstract or search phrase:')
# Skip empty and whitespace-only input so a stray space or newline does not
# trigger a full (and meaningless) search.
if search_query.strip():
    with st.spinner(text="Searching and sorting results."):
        # Show a temporary image while the search is running.
        placeholder = st.empty()
        with placeholder.container():
            st.image('https://www.dropbox.com/s/yndn6lkesjga9a6/emerac.png?raw=1')
        try:
            # search() reads these two models as module-level globals.
            bi_encoder = sent_trans_load()
            cross_encoder = sent_cross_load()
            search(search_query)
        finally:
            # Always remove the temporary image, even if the search fails.
            placeholder.empty()