ola13 commited on
Commit
014aa64
β€’
1 Parent(s): 2cc82f7
Files changed (1) hide show
  1. app.py +165 -177
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import json
2
  import os
 
 
 
3
  import gradio as gr
4
  import requests
5
  from huggingface_hub import HfApi
6
- import traceback
7
-
8
 
9
  hf_api = HfApi()
10
  roots_datasets = {
@@ -54,35 +55,40 @@ def process_pii(text):
54
  return text
55
 
56
 
57
- def format_meta(result):
58
- meta_html = (
59
- """
60
- <p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
61
- <a href='{}' target='_blank'>{}</a></p>""".format(
62
- result["meta"]["url"], result["meta"]["url"]
 
 
 
 
 
 
 
 
 
 
63
  )
64
- if "meta" in result and result["meta"] is not None and "url" in result["meta"]
65
- else ""
66
- )
67
- docid_html = get_docid_html(result["docid"])
68
- return """{}
69
- <p style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</p>
70
- <p style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</p>
71
- """.format(
72
- meta_html,
73
- docid_html,
74
- result["lang"] if lang in result else None,
75
- )
76
- return meta_html
77
 
 
 
 
 
78
 
79
- def process_results(results, highlight_terms):
80
- if len(results) == 0:
81
- return """<br><p style='font-family: Arial; color:Silver; text-align: center;'>
82
- No results retrieved.</p><br><hr>"""
83
- results_html = ""
84
- for result in results:
85
- tokens = result["text"].split()
 
 
 
 
86
  tokens_html = []
87
  for token in tokens:
88
  if token in highlight_terms:
@@ -90,172 +96,131 @@ def process_results(results, highlight_terms):
90
  else:
91
  tokens_html.append(token)
92
  tokens_html = " ".join(tokens_html)
93
- tokens_html = process_pii(tokens_html)
94
- meta_html = format_meta(result)
95
- meta_html += """
96
- <p style='font-family: Arial;'>{}</p>
97
- <br>
98
- """.format(
99
- tokens_html
100
- )
101
- results_html += meta_html
102
- return results_html + "<hr>"
103
-
104
 
105
- def process_exact_match_payload(payload, query):
106
- datasets = set()
107
- results = payload["results"]
108
- results_html = (
109
- "<p style='font-family: Arial;'>Total nubmer of results: {}</p>".format(
110
- payload["num_results"]
111
  )
 
 
 
 
 
 
 
 
 
 
 
 
112
  )
113
- for result in results:
114
- _, dataset, _ = result["docid"].split("/")
115
- datasets.add(dataset)
116
- text = result["text"]
117
- meta_html = format_meta(result)
118
-
119
- query_start = text.find(query)
120
- query_end = query_start + len(query)
121
- tokens_html = text[0:query_start]
122
- tokens_html += "<b>{}</b>".format(text[query_start:query_end])
123
- tokens_html += text[query_end:]
124
- result_html = (
125
- meta_html
126
- + """
127
- <p style='font-family: Arial;'>{}</p>
128
- <br>
129
- """.format(
130
- tokens_html
131
- )
132
- )
133
- results_html += result_html
134
- return results_html + "<hr>", list(datasets)
135
-
136
 
137
- def process_bm25_match_payload(payload, language):
138
- if "err" in payload:
139
- if payload["err"]["type"] == "unsupported_lang":
140
- detected_lang = payload["err"]["meta"]["detected_lang"]
141
- return f"""
142
- <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
143
- Detected language <b>{detected_lang}</b> is not supported.<br>
144
- Please choose a language from the dropdown or type another query.
145
- </p><br><hr><br>"""
146
 
147
- results = payload["results"]
148
- highlight_terms = payload["highlight_terms"]
 
 
149
 
150
- if language == "detect_language":
151
- return (
152
- (
153
- (
154
- f"""<p style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
155
- Detected language: <b>{results[0]["lang"]}</b></p><br><hr><br>"""
156
- if len(results) > 0 and language == "detect_language"
157
- else ""
158
- )
159
- + process_results(results, highlight_terms)
160
- ),
161
- [],
162
- )
163
 
164
- if language == "all":
165
- datasets = set()
166
- get_docid_html(result["docid"])
167
- results_html = ""
168
- for lang, results_for_lang in results.items():
169
- if len(results_for_lang) == 0:
170
  results_html += f"""<p style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
171
- No results for language: <b>{lang}</b><hr></p>"""
172
- continue
173
-
174
- collapsible_results = f"""
 
 
 
 
 
 
 
 
175
  <details>
176
  <summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
177
  Results for language: <b>{lang}</b><hr>
178
  </summary>
179
- {process_results(results_for_lang, highlight_terms)}
180
  </details>"""
181
- results_html += collapsible_results
182
- for r in results_for_lang:
183
- _, dataset, _ = r["docid"].split("/")
184
- datasets.add(dataset)
185
- return results_html, list(datasets)
186
 
187
- datasets = set()
188
- for r in results:
189
- _, dataset, _ = r["docid"].split("/")
190
- datasets.add(dataset)
191
- return process_results(results, highlight_terms), list(datasets)
192
 
193
 
194
- def scisearch(query, language, num_results=10):
195
- datasets = []
196
- try:
197
- query = query.strip()
198
- exact_search = False
199
- if query.startswith('"') and query.endswith('"') and len(query) >= 2:
200
- exact_search = True
201
- query = query[1:-1]
202
- else:
203
- query = " ".join(query.split())
204
- if query == "" or query is None:
205
- return ""
206
- post_data = {"query": query, "k": num_results}
207
- if language != "detect_language":
208
- post_data["lang"] = language
209
- address = (
210
- "http://34.105.160.81:8080" if exact_search else os.environ.get("address")
211
- )
212
- output = requests.post(
213
- address,
214
- headers={"Content-type": "application/json"},
215
- data=json.dumps(post_data),
216
- timeout=60,
217
- )
218
- payload = json.loads(output.text)
219
- return (
220
- process_bm25_match_payload(payload, language)
221
- if not exact_search
222
- else process_exact_match_payload(payload, query)
223
- )
224
- except Exception as e:
225
- results_html = f"""
226
- <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
227
- Raised {type(e).__name__}</p>
228
- <p style='font-size:14px; font-family: Arial; '>
229
- Check if a relevant discussion already exists in the Community tab. If not, please open a discussion.
230
- </p>
231
- """
232
- print(e)
233
- print(traceback.format_exc())
234
- return results_html, datasets
235
 
 
 
 
236
 
237
- def flag(query, language, num_results, issue_description):
238
- try:
239
- post_data = {
240
- "query": query,
241
- "k": num_results,
242
- "flag": True,
243
- "description": issue_description,
244
- }
245
- if language != "detect_language":
246
- post_data["lang"] = language
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- output = requests.post(
249
- os.environ.get("address"),
250
- headers={"Content-type": "application/json"},
251
- data=json.dumps(post_data),
252
- timeout=120,
253
- )
254
 
255
- results = json.loads(output.text)
256
- except:
257
- print("Error flagging")
258
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
 
261
  description = """# <p style="text-align: center;"> 🌸 πŸ”Ž ROOTS search tool πŸ” 🌸 </p>
@@ -338,15 +303,38 @@ if __name__ == "__main__":
338
  def submit(query, lang, k, dropdown_input):
339
  print("submitting", query, lang, k)
340
  query = query.strip()
341
- if query is None or query == "":
 
 
 
 
 
 
342
  return "", ""
343
- results_html, datasets = scisearch(query, lang, k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  print(datasets)
345
  return {
346
  results: results_html,
347
  flagging_form: gr.update(visible=True),
348
  datasets_filter: gr.update(visible=True),
349
- available_datasets: gr.Dropdown.update(choices=datasets, value=datasets),
 
 
350
  }
351
 
352
  def filter_datasets():
 
1
  import json
2
  import os
3
+ import traceback
4
+ from typing import List, Tuple
5
+
6
  import gradio as gr
7
  import requests
8
  from huggingface_hub import HfApi
 
 
9
 
10
  hf_api = HfApi()
11
  roots_datasets = {
 
55
  return text
56
 
57
 
58
+ def flag(query, language, num_results, issue_description):
59
+ try:
60
+ post_data = {
61
+ "query": query,
62
+ "k": num_results,
63
+ "flag": True,
64
+ "description": issue_description,
65
+ }
66
+ if language != "detect_language":
67
+ post_data["lang"] = language
68
+
69
+ output = requests.post(
70
+ os.environ.get("address"),
71
+ headers={"Content-type": "application/json"},
72
+ data=json.dumps(post_data),
73
+ timeout=120,
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ results = json.loads(output.text)
77
+ except:
78
+ print("Error flagging")
79
+ return ""
80
 
81
+
82
+ def format_result(result, highlight_terms, exact_search):
83
+ text, url, docid = result
84
+ if exact_search:
85
+ query_start = text.find(highlight_terms)
86
+ query_end = query_start + len(highlight_terms)
87
+ tokens_html = text[0:query_start]
88
+ tokens_html += "<b>{}</b>".format(text[query_start:query_end])
89
+ tokens_html += text[query_end:]
90
+ else:
91
+ tokens = text.split()
92
  tokens_html = []
93
  for token in tokens:
94
  if token in highlight_terms:
 
96
  else:
97
  tokens_html.append(token)
98
  tokens_html = " ".join(tokens_html)
99
+ tokens_html = process_pii(tokens_html)
 
 
 
 
 
 
 
 
 
 
100
 
101
+ meta_html = (
102
+ """<p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
103
+ <a href='{}' target='_blank'>{}</a></p>""".format(
104
+ url, url
 
 
105
  )
106
+ if url is not None
107
+ else ""
108
+ )
109
+ docid_html = get_docid_html(docid)
110
+ language = "FIXME"
111
+ return """{}
112
+ <p style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</p>
113
+ <p style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</p>
114
+ <p style='font-family: Arial;'>{}</p>
115
+ <br>
116
+ """.format(
117
+ meta_html, docid_html, language, tokens_html
118
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
 
 
 
 
 
 
 
120
 
121
+ def format_result_page(
122
+ language, results, highlight_terms, num_results, exact_search
123
+ ) -> gr.HTML:
124
+ header_html = ""
125
 
126
+ # FIX lang detection by normalizing format on the backend
127
+ if language == "detect_language" and not exact_search:
128
+ header_html = f"""<p style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
129
+ Detected language: <b> FIX MEEEE !!! </b></p><br><hr><br>"""
 
 
 
 
 
 
 
 
 
130
 
131
+ results_html = ""
132
+ for lang, results_for_lang in results.items():
133
+ if len(results_for_lang) == 0:
134
+ if exact_search:
 
 
135
  results_html += f"""<p style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
136
+ No results found.<hr></p>"""
137
+ else:
138
+ results_html += f"""<p style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
139
+ No results for language: <b>{lang}</b><hr></p>"""
140
+ continue
141
+ results_for_lang_html = ""
142
+ for result in results_for_lang:
143
+ results_for_lang_html += format_result(
144
+ result, highlight_terms, exact_search
145
+ )
146
+ if language == "all" and not exact_search:
147
+ results_for_lang_html = f"""
148
  <details>
149
  <summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
150
  Results for language: <b>{lang}</b><hr>
151
  </summary>
152
+ {results_for_lang_html}
153
  </details>"""
154
+ results_html += results_for_lang_html
 
 
 
 
155
 
156
+ return header_html + results_html
 
 
 
 
157
 
158
 
159
+ def extract_results_from_payload(query, language, payload, exact_search):
160
+ results = payload["results"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ processed_results = dict()
163
+ highlight_terms = None
164
+ num_results = None
165
 
166
+ if exact_search:
167
+ highlight_terms = query
168
+ num_results = payload["num_results"]
169
+ results = {language: results}
170
+ else:
171
+ highlight_terms = payload["highlight_terms"]
172
+ # unify format - might be best fixed on server side
173
+ if language != "all":
174
+ results = {language: results}
175
+
176
+ for lang, results_for_lang in results.items():
177
+ processed_results[lang] = list()
178
+ for result in results_for_lang:
179
+ text = result["text"]
180
+ url = (
181
+ result["meta"]["url"]
182
+ if "meta" in result
183
+ and result["meta"] is not None
184
+ and "url" in result["meta"]
185
+ else None
186
+ )
187
+ docid = result["docid"]
188
+ processed_results[lang].append((text, url, docid))
189
 
190
+ return processed_results, highlight_terms, num_results
 
 
 
 
 
191
 
192
+
193
+ def process_error(error_type):
194
+ if error_type == "unsupported_lang":
195
+ detected_lang = payload["err"]["meta"]["detected_lang"]
196
+ return f"""
197
+ <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
198
+ Detected language <b>{detected_lang}</b> is not supported.<br>
199
+ Please choose a language from the dropdown or type another query.
200
+ </p><br><hr><br>"""
201
+
202
+
203
+ def extract_error_from_payload(payload):
204
+ if "err" in payload:
205
+ return payload["err"]["type"]
206
+ return None
207
+
208
+
209
+ def request_payload(
210
+ query, language, exact_search, num_results=10
211
+ ) -> List[Tuple[str, str]]:
212
+ post_data = {"query": query, "k": num_results}
213
+ if language != "detect_language":
214
+ post_data["lang"] = language
215
+ address = "http://34.105.160.81:8080" if exact_search else os.environ.get("address")
216
+ output = requests.post(
217
+ address,
218
+ headers={"Content-type": "application/json"},
219
+ data=json.dumps(post_data),
220
+ timeout=60,
221
+ )
222
+ payload = json.loads(output.text)
223
+ return payload
224
 
225
 
226
  description = """# <p style="text-align: center;"> 🌸 πŸ”Ž ROOTS search tool πŸ” 🌸 </p>
 
303
  def submit(query, lang, k, dropdown_input):
304
  print("submitting", query, lang, k)
305
  query = query.strip()
306
+ exact_search = False
307
+ if query.startswith('"') and query.endswith('"') and len(query) >= 2:
308
+ exact_search = True
309
+ query = query[1:-1]
310
+ else:
311
+ query = " ".join(query.split())
312
+ if query == "" or query is None:
313
  return "", ""
314
+
315
+ results_html = ""
316
+ payload = request_payload(query, lang, exact_search, k)
317
+ err = extract_error_from_payload(payload)
318
+ if err is not None:
319
+ results_html = process_error(err)
320
+ else:
321
+ (
322
+ processed_results,
323
+ highlight_terms,
324
+ num_results,
325
+ ) = extract_results_from_payload(query, lang, payload, exact_search)
326
+ results_html = format_result_page(
327
+ lang, processed_results, highlight_terms, num_results, exact_search
328
+ )
329
+ datasets = []
330
  print(datasets)
331
  return {
332
  results: results_html,
333
  flagging_form: gr.update(visible=True),
334
  datasets_filter: gr.update(visible=True),
335
+ available_datasets: gr.Dropdown.update(
336
+ choices=datasets, value=datasets
337
+ ),
338
  }
339
 
340
  def filter_datasets():