NealCaren commited on
Commit
ee21672
1 Parent(s): a1bce3f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import numpy as np
4
+
5
+ import pickle
6
+ from collections import OrderedDict
7
+
8
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
9
+ import torch
10
+
11
+ from nltk.tokenize import sent_tokenize
12
+
13
+ import nltk
14
+ nltk.download('punkt')
15
+
16
+
17
+ if not torch.cuda.is_available():
18
+ print("Warning: No GPU found. Please add GPU to your notebook")
19
+
20
+
21
+ import pandas as pd
22
+ st.title('Sociology Paragraph Search')
23
+
24
+ st.write('This page is a work-in-progress that allows you to search through articles recently published in a few sociology journals and retrieve the most relevant paragraphs. ')
25
+
26
+ st.markdown('''Notes:
27
+ * To get the best results, search like you are using Google. My best luck comes from phrases, such as "social movements and public opinion", "inequality in latin america", "race color skin tone measurement", "audit study experiment gender", "crenshaw intersectionality" or "logistic regression or linear probability model".
28
+ * The dataset currently includes only article published since 2016 in Social Forces, Social Problems, Sociology of Race and Ethnicity, Gender and Society, Socius, JHSB, and the American Sociological Review (approximately 100K paragraphs from 2K articles).
29
+ * The most relevant paragarph to your search is returned first, along with up to four other related paragraphs from that article.
30
+ * The most relevant sentence within each paragraph, as determined by math, is bolded.
31
+ * Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
32
+ * Let [me](mailto:neal.caren@unc.edu) know what you think.
33
+ ''')
34
+
35
+
36
+ def sent_trans_load():
37
+ #We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
38
+ bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
39
+ bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens, max 512
40
+ return bi_encoder
41
+
42
+ def sent_cross_load():
43
+ #We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
44
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
45
+ return cross_encoder
46
+
47
+
48
+ @st.cache
49
+ def load_data():
50
+ dfs = [pd.read_json(f'data/passages_{i}.jsonl', lines=True) for i in range(0,5)]
51
+ df = pd.concat(dfs)
52
+ df.reset_index(inplace=True, drop=True)
53
+ return df
54
+
55
+
56
+ with st.spinner(text="Loading data..."):
57
+ df = load_data()
58
+ passages = df['text'].values
59
+
60
+ @st.cache
61
+ def load_embeddings():
62
+ efs = [np.load(f'data/embeddings_{i}.pt.npy') for i in range(0,5)]
63
+ corpus_embeddings = np.concatenate(efs)
64
+ return corpus_embeddings
65
+
66
+ with st.spinner(text="Loading embeddings..."):
67
+ corpus_embeddings = load_embeddings()
68
+
69
+
70
+
71
+
72
+
73
+ def search(query, top_k=40):
74
+
75
+ ##### Sematic Search #####
76
+ # Encode the query using the bi-encoder and find potentially relevant passages
77
+ question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
78
+
79
+
80
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
81
+ hits = hits[0] # Get the hits for the first query
82
+ ##### Re-Ranking #####
83
+ # Now, score all retrieved passages with the cross_encoder
84
+ cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
85
+ cross_scores = cross_encoder.predict(cross_inp)
86
+
87
+ # Sort results by the cross-encoder scores
88
+ for idx in range(len(cross_scores)):
89
+ hits[idx]['cross-score'] = cross_scores[idx]
90
+
91
+ # Output of top-5 hits from re-ranker
92
+ print("\n-------------------------\n")
93
+ print("Search Results")
94
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
95
+
96
+ hd = OrderedDict()
97
+ for hit in hits[0:20]:
98
+
99
+ row_id = hit['corpus_id']
100
+ cite = df.loc[row_id]['cite']
101
+ #graph = passages[row_id]
102
+ graph = df.loc[row_id]['text']
103
+
104
+ # Find best sentence
105
+ ab_sentences= [s for s in sent_tokenize(graph)]
106
+ cross_inp = [[query, s] for s in ab_sentences]
107
+ cross_scores = cross_encoder.predict(cross_inp)
108
+ thesis = pd.Series(cross_scores, ab_sentences).sort_values().index[-1]
109
+ graph = graph.replace(thesis, f'**{thesis}**')
110
+
111
+ if cite in hd:
112
+
113
+ hd[cite].append(graph)
114
+ else:
115
+ hd[cite] = [graph]
116
+
117
+ for cite, graphs in hd.items():
118
+ cite = cite.replace(", ", '. "').replace(', Social ', '", Social ')
119
+ st.write(cite)
120
+ for graph in graphs[:5]:
121
+ st.write(f'* {graph}')
122
+ st.write('')
123
+ # print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
124
+
125
+
126
+
127
+ search_query = st.text_input('Enter your search phrase:')
128
+ if search_query!='':
129
+ with st.spinner(text="Searching and sorting results (may take up to 30 seconds)"):
130
+ bi_encoder = sent_trans_load()
131
+ cross_encoder = sent_cross_load()
132
+ search(search_query)