domenicrosati committed on
Commit
4c36cd4
·
1 Parent(s): 8890bde

add strict relevancy and scite badges and reranking

Browse files
Files changed (3) hide show
  1. README.md +0 -2
  2. app.py +94 -27
  3. requirements.txt +3 -0
README.md CHANGED
@@ -9,5 +9,3 @@ app_file: app.py
9
  pinned: false
10
  license: cc-by-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  license: cc-by-2.0
11
  ---
 
 
app.py CHANGED
@@ -2,15 +2,38 @@ import streamlit as st
2
  from transformers import pipeline
3
  import requests
4
  from bs4 import BeautifulSoup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
7
 
 
8
  def remove_html(x):
9
  soup = BeautifulSoup(x, 'html.parser')
10
  text = soup.get_text()
11
  return text
12
 
13
- def search(term, limit=25):
 
 
 
14
  search = f"https://api.scite.ai/search?mode=citations&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
15
  req = requests.get(
16
  search,
@@ -19,8 +42,9 @@ def search(term, limit=25):
19
  }
20
  )
21
  return (
22
- remove_html('\n'.join(['\n'.join([cite['snippet'] for cite in doc['citations']]) for doc in req.json()['hits']])),
23
- [(doc['doi'], doc['citations'], doc['title']) for doc in req.json()['hits']]
 
24
  )
25
 
26
 
@@ -39,25 +63,37 @@ def find_source(text, docs):
39
  'source_title': doc[2],
40
  'source_link': f"https://scite.ai/reports/{doc[0]}"
41
  }
42
- return {
43
- 'citation_statement': '',
44
- 'text': text,
45
- 'from': '',
46
- 'supporting': '',
47
- 'source_title': '',
48
- 'source_link': ''
49
- }
50
 
51
  @st.experimental_singleton
52
  def init_models():
53
- question_answerer = pipeline("question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B')
54
- return question_answerer
 
 
 
 
 
 
 
 
 
55
 
56
- qa_model = init_models()
 
 
 
 
 
 
 
 
57
 
58
 
59
- def card(title, context, score, link):
60
- return st.markdown(f"""
 
61
  <div class="container-fluid">
62
  <div class="row align-items-start">
63
  <div class="col-md-12 col-sm-12">
@@ -72,6 +108,22 @@ def card(title, context, score, link):
72
  </div>
73
  </div>
74
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  st.title("Scientific Question Answering with Citations")
77
 
@@ -85,8 +137,14 @@ st.markdown("""
85
  """, unsafe_allow_html=True)
86
 
87
  def run_query(query):
88
- context, orig_docs = search(query)
89
- if not context.strip():
 
 
 
 
 
 
90
  return st.markdown("""
91
  <div class="container-fluid">
92
  <div class="row align-items-start">
@@ -97,35 +155,44 @@ def run_query(query):
97
  </div>
98
  """, unsafe_allow_html=True)
99
 
 
 
 
 
 
 
100
  results = []
101
  model_results = qa_model(question=query, context=context, top_k=10)
102
  for result in model_results:
103
  support = find_source(result['answer'], orig_docs)
 
 
104
  results.append({
105
  "answer": support['text'],
106
  "title": support['source_title'],
107
  "link": support['source_link'],
108
  "context": support['citation_statement'],
109
- "score": result['score']
 
110
  })
111
 
112
-
113
-
114
  sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)
115
  sorted_result = list({
116
  result['context']: result for result in sorted_result
117
  }.values())
118
- sorted_result = sorted(sorted_result, key=lambda x: x['score'], reverse=True)
119
-
120
 
121
  for r in sorted_result:
122
  answer = r["answer"]
123
- ctx = remove_html(r["context"]).replace(answer, f"<mark>{answer}</mark>").replace('<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
124
- title = r["title"].replace("_", " ")
 
125
  score = round(r["score"], 4)
126
- card(title, ctx, score, r['link'])
127
 
128
  query = st.text_input("Ask scientific literature a question", "")
129
 
130
  if query != "":
131
- run_query(query)
 
 
2
  from transformers import pipeline
3
  import requests
4
  from bs4 import BeautifulSoup
5
+ from nltk.corpus import stopwords
6
+ import nltk
7
+ import string
8
+ from streamlit.components.v1 import html
9
+ from sentence_transformers.cross_encoder import CrossEncoder as CE
10
+ import numpy as np
11
+ from typing import List, Tuple
12
+ import torch
13
+
14
class CrossEncoder:
    """Thin wrapper around the sentence-transformers cross-encoder (``CE``).

    The wrapped model is stored on ``self.model``; ``predict`` forwards to it.
    """

    def __init__(self, model_path: str, **kwargs):
        """Build the underlying ``CE`` model.

        Args:
            model_path: Model name or path handed to ``CE``.
            **kwargs: Extra keyword arguments passed to ``CE`` unchanged.
        """
        self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
        """Score each pair in ``sentences`` with the wrapped model.

        Args:
            sentences: Pairs of strings to score.
            batch_size: Batch size forwarded to the model.
            show_progress_bar: Whether the model should show a progress bar.

        Returns:
            Whatever the wrapped model's ``predict`` returns for the pairs.
        """
        scores = self.model.predict(
            sentences=sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )
        return scores
23
+
24
 
25
  SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
26
 
27
+
28
  def remove_html(x):
29
  soup = BeautifulSoup(x, 'html.parser')
30
  text = soup.get_text()
31
  return text
32
 
33
+
34
+ def search(term, limit=10, clean=True, strict=True):
35
+ term = clean_query(term, clean=clean, strict=strict)
36
+ # heuristic, 2 searches strict and not? and then merge?
37
  search = f"https://api.scite.ai/search?mode=citations&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
38
  req = requests.get(
39
  search,
 
42
  }
43
  )
44
  return (
45
+ [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
46
+ [(doc['doi'], doc['citations'], doc['title'])
47
+ for doc in req.json()['hits']]
48
  )
49
 
50
 
 
63
  'source_title': doc[2],
64
  'source_link': f"https://scite.ai/reports/{doc[0]}"
65
  }
66
+ return None
67
+
 
 
 
 
 
 
68
 
69
  @st.experimental_singleton
70
  def init_models():
71
+ nltk.download('stopwords')
72
+ stop = set(stopwords.words('english') + list(string.punctuation))
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ question_answerer = pipeline(
75
+ "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
76
+ device=device
77
+ )
78
+ reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
79
+ return question_answerer, reranker, stop, device
80
+
81
+ qa_model, reranker, stop, device = init_models()
82
 
83
def clean_query(query, strict=True, clean=True):
    """Normalise a free-text question into a search term.

    The query is lower-cased and split on whitespace. When ``clean`` is true,
    punctuation is stripped from every token and tokens found in the
    module-level ``stop`` set are dropped. The remaining tokens are joined
    with ``' AND '`` when ``strict`` is true, otherwise with a single space.

    Args:
        query: The raw user question.
        strict: Join terms with ``' AND '`` instead of a space.
        clean: Remove stopwords and punctuation. When false the tokens are
            kept as they are (previously this path returned an empty string
            because the filter discarded every token).

    Returns:
        The joined search term; empty if no token survives.
    """
    operator = ' AND ' if strict else ' '
    # split() with no argument collapses runs of whitespace, so no empty
    # tokens end up between operators (e.g. "a AND  AND b").
    tokens = query.lower().split()
    if clean:
        strip_punct = str.maketrans('', '', string.punctuation)
        kept = []
        for raw in tokens:
            bare = raw.translate(strip_punct)
            # Test both forms: "don't" is listed with its apostrophe, while
            # "this?" only matches the stop set once the "?" is removed.
            if bare and raw not in stop and bare not in stop:
                kept.append(bare)
        tokens = kept
    return operator.join(tokens)
92
 
93
 
94
+
95
+ def card(title, context, score, link, supporting):
96
+ st.markdown(f"""
97
  <div class="container-fluid">
98
  <div class="row align-items-start">
99
  <div class="col-md-12 col-sm-12">
 
108
  </div>
109
  </div>
110
  """, unsafe_allow_html=True)
111
+ html(f"""
112
+ <div
113
+ class="scite-badge"
114
+ data-doi="{supporting}"
115
+ data-layout="horizontal"
116
+ data-show-zero="false"
117
+ data-show-labels="false"
118
+ data-tally-show="true"
119
+ />
120
+ <script
121
+ async
122
+ type="application/javascript"
123
+ src="https://cdn.scite.ai/badge/scite-badge-latest.min.js">
124
+ </script>
125
+ """, width=None, height=42, scrolling=False)
126
+
127
 
128
  st.title("Scientific Question Answering with Citations")
129
 
 
137
  """, unsafe_allow_html=True)
138
 
139
  def run_query(query):
140
+ if device == 'cpu':
141
+ limit = 50
142
+ context_limit = 10
143
+ else:
144
+ limit = 100
145
+ context_limit = 25
146
+ contexts, orig_docs = search(query, limit=limit)
147
+ if len(contexts) == 0 or not ''.join(contexts).strip():
148
  return st.markdown("""
149
  <div class="container-fluid">
150
  <div class="row align-items-start">
 
155
  </div>
156
  """, unsafe_allow_html=True)
157
 
158
+ sentence_pairs = [[query, context] for context in contexts]
159
+ scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
160
+ hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
161
+ sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
162
+
163
+ context = '\n'.join(sorted_contexts[:context_limit])
164
  results = []
165
  model_results = qa_model(question=query, context=context, top_k=10)
166
  for result in model_results:
167
  support = find_source(result['answer'], orig_docs)
168
+ if not support:
169
+ continue
170
  results.append({
171
  "answer": support['text'],
172
  "title": support['source_title'],
173
  "link": support['source_link'],
174
  "context": support['citation_statement'],
175
+ "score": result['score'],
176
+ "doi": support["supporting"]
177
  })
178
 
 
 
179
  sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)
180
  sorted_result = list({
181
  result['context']: result for result in sorted_result
182
  }.values())
183
+ sorted_result = sorted(
184
+ sorted_result, key=lambda x: x['score'], reverse=True)
185
 
186
  for r in sorted_result:
187
  answer = r["answer"]
188
+ ctx = remove_html(r["context"]).replace(answer, f"<mark>{answer}</mark>").replace(
189
+ '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
190
+ title = r.get("title", '').replace("_", " ")
191
  score = round(r["score"], 4)
192
+ card(title, ctx, score, r['link'], r['doi'])
193
 
194
  query = st.text_input("Ask scientific literature a question", "")
195
 
196
  if query != "":
197
+ with st.spinner('Loading...'):
198
+ run_query(query)
requirements.txt CHANGED
@@ -3,3 +3,6 @@ requests
3
  beautifulsoup4
4
  streamlit==1.2.0
5
  torch
 
 
 
 
3
  beautifulsoup4
4
  streamlit==1.2.0
5
  torch
6
+ nltk
7
+ sentence_transformers
8
+ numpy