import streamlit as st
import numpy as np
import pickle
from collections import OrderedDict
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from nltk.tokenize import sent_tokenize
import nltk
import pandas as pd

# Number of data/passages_{i}.jsonl / data/embeddings_{i}.pt.npy shard pairs on disk.
NUM_SHARDS = 5
# How many re-ranked hits are rendered (grouped by article).
TOP_DISPLAY = 20
# Maximum paragraphs rendered under one article citation.
MAX_GRAPHS_PER_ARTICLE = 5

# Streamlit re-executes this whole script on every interaction, so only hit
# the network when the Punkt sentence tokenizer is actually missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")

st.title('Sociology Paragraph Search')
st.write('This page is a work-in-progress that allows you to search through articles recently published in a few sociology journals and retrieve the most relevant paragraphs. ')
st.markdown('''Notes:
* To get the best results, search like you are using Google. My best luck comes from phrases, such as "social movements and public opinion", "inequality in latin america", "race color skin tone measurement", "audit study experiment gender", "crenshaw intersectionality" or "logistic regression or linear probability model".
* The dataset currently includes only articles published since 2016 in Social Forces, Social Problems, Sociology of Race and Ethnicity, Gender and Society, Socius, JHSB, and the American Sociological Review (approximately 100K paragraphs from 2K articles).
* The most relevant paragraph to your search is returned first, along with up to four other related paragraphs from that article.
* The most relevant sentence within each paragraph, as determined by math, is bolded.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think.
''')


@st.cache(allow_output_mutation=True)
def sent_trans_load():
    """Load the bi-encoder used to embed the query for first-stage retrieval.

    Cached so the model is built once per server process rather than on every
    search; ``allow_output_mutation`` stops Streamlit from hashing the model.
    """
    # NOTE(review): presumably the same model that produced
    # data/embeddings_*.npy — query and corpus vectors must share a space; confirm.
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # Truncate long inputs to 256 tokens, max 512
    return bi_encoder


@st.cache(allow_output_mutation=True)
def sent_cross_load():
    """Load the cross-encoder that scores (query, text) pairs for re-ranking.

    Cached for the same reason as ``sent_trans_load``.
    """
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


@st.cache
def load_data():
    """Load all passage shards and concatenate them into one DataFrame.

    The result has a fresh 0..n-1 index so that a row label can be used
    directly as the ``corpus_id`` returned by semantic search.
    """
    # Assumes passages_{i} and embeddings_{i} are row-aligned shards — TODO confirm.
    dfs = [pd.read_json(f'data/passages_{i}.jsonl', lines=True) for i in range(NUM_SHARDS)]
    return pd.concat(dfs, ignore_index=True)


with st.spinner(text="Loading data..."):
    df = load_data()
    passages = df['text'].values


@st.cache
def load_embeddings():
    """Load all embedding shards and stack them in the same order as ``load_data``."""
    efs = [np.load(f'data/embeddings_{i}.pt.npy') for i in range(NUM_SHARDS)]
    return np.concatenate(efs)


with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()


def _rerank(query, hits):
    """Score every bi-encoder hit with the cross-encoder; return hits best-first."""
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda h: h['cross-score'], reverse=True)


def _bold_best_sentence(query, graph):
    """Return *graph* with its highest cross-encoder-scoring sentence wrapped in ``**``."""
    sentences = sent_tokenize(graph)
    if not sentences:
        # Empty/whitespace-only paragraph: nothing to score or highlight.
        return graph
    scores = cross_encoder.predict([[query, s] for s in sentences])
    best = sentences[int(np.argmax(scores))]
    return graph.replace(best, f'**{best}**')


def _render_results(grouped):
    """Write each article citation followed by up to MAX_GRAPHS_PER_ARTICLE paragraphs."""
    for cite, graphs in grouped.items():
        # NOTE(review): the second replace can never match, because the first
        # already rewrote every ", " — confirm the intended citation format.
        cite = cite.replace(", ", '. "').replace(', Social ', '", Social ')
        st.write(cite)
        for graph in graphs[:MAX_GRAPHS_PER_ARTICLE]:
            st.write(f'* {graph}')
        st.write('')


def search(query, top_k=40):
    """Retrieve, re-rank and render the paragraphs most relevant to *query*.

    Uses the module-level ``bi_encoder`` and ``cross_encoder``, which must be
    assigned before this is called (done just above the call at the bottom of
    this script).

    Args:
        query: Free-text search phrase.
        top_k: Number of candidate passages the bi-encoder retrieves before
            cross-encoder re-ranking.
    """
    # Semantic search: embed the query and pull candidate passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        return

    # Re-ranking: the cross-encoder is slower but more precise than the bi-encoder.
    hits = _rerank(query, hits)
    print("\n-------------------------\n")
    print("Search Results")

    # Group the top re-ranked paragraphs by article; the OrderedDict keeps
    # articles in the rank order of their best paragraph.
    grouped = OrderedDict()
    for hit in hits[:TOP_DISPLAY]:
        row_id = hit['corpus_id']
        cite = df.at[row_id, 'cite']
        grouped.setdefault(cite, []).append(_bold_best_sentence(query, passages[row_id]))

    _render_results(grouped)


search_query = st.text_input('Enter your search phrase:')
if search_query.strip():
    with st.spinner(text="Searching and sorting results (may take up to 30 seconds)"):
        # Both loaders are cached, so only the first search pays the load cost.
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)