scisearch / app.py
ola13's picture
refactor towards flagging
a949c83
raw
history blame
No virus
18.2 kB
import json
import os
import traceback
from typing import List, Tuple
import gradio as gr
import requests
from huggingface_hub import HfApi
# Shared client for the Hugging Face Hub API.
hf_api = HfApi()

# Index of every dataset published by the "bigscience-data" author, keyed by
# the last path component of the dataset id (the part after the final "/").
# The access token is read from the "bigscience_data_token" environment
# variable.
roots_datasets = dict(
    (dataset_info.id.rpartition("/")[2], dataset_info)
    for dataset_info in hf_api.list_datasets(
        author="bigscience-data",
        use_auth_token=os.environ.get("bigscience_data_token"),
    )
)
def get_docid_html(docid):
    """Render a document id as an HTML snippet linking to its dataset.

    Args:
        docid: Identifier of the form ``<org>/<dataset>/<document id>``.

    Returns:
        HTML string with a link to the dataset page on the Hugging Face Hub
        followed by the document id. Private datasets get a lock icon and a
        red link; public datasets show their license in the link tooltip.

    Raises:
        ValueError: If ``docid`` does not have exactly three "/"-separated parts.
        KeyError: If the dataset is not present in ``roots_datasets``.
    """
    _data_org, dataset, docid = docid.split("/")
    metadata = roots_datasets[dataset]
    if metadata.private:
        docid_html = """
<a title="This dataset is private. See the introductory text for more information"
style="color:#AA4A44; font-weight: bold; text-decoration:none"
onmouseover="style='color:#AA4A44; font-weight: bold; text-decoration:underline'"
onmouseout="style='color:#AA4A44; font-weight: bold; text-decoration:none'"
href="https://huggingface.co/datasets/bigscience-data/{dataset}"
target="_blank">
🔒{dataset}
</a>
<span style="color:#7978FF; ">/{docid}</span>""".format(
            dataset=dataset, docid=docid
        )
    else:
        # Prefer an explicit "license:<name>" tag; the first tag is only used
        # as a fallback (the previous behaviour), and a dataset without any
        # tags no longer raises IndexError.
        tags = list(getattr(metadata, "tags", None) or [])
        license_tag = next(
            (tag for tag in tags if tag.startswith("license:")), None
        )
        if license_tag is None:
            license_tag = tags[0] if tags else "unknown"
        docid_html = """
<a title="This dataset is licensed {metadata}"
style="color:#7978FF; font-weight: bold; text-decoration:none"
onmouseover="style='color:#7978FF; font-weight: bold; text-decoration:underline'"
onmouseout="style='color:#7978FF; font-weight: bold; text-decoration:none'"
href="https://huggingface.co/datasets/bigscience-data/{dataset}"
target="_blank">
{dataset}
</a>
<span style="color:#7978FF; ">/{docid}</span>""".format(
            metadata=license_tag.split(":")[-1], dataset=dataset, docid=docid
        )
    return docid_html
# Kinds of personally identifiable information that appear in the text as
# placeholders of the form PII_PREFIX + tag (e.g. "PI:EMAIL").
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every PII placeholder in ``text`` with a highlighted marker.

    Each occurrence of ``PII_PREFIX + tag`` (for every tag in ``PII_TAGS``)
    is substituted with a bold ``<mark>`` element reading "REDACTED <tag>".

    Args:
        text: String to post-process.

    Returns:
        The string with all placeholders replaced by HTML markup.
    """
    redaction_markup = """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>"""
    for pii_kind in PII_TAGS:
        placeholder = PII_PREFIX + pii_kind
        text = text.replace(placeholder, redaction_markup.format(pii_kind))
    return text
def flag(query, language, num_results, issue_description):
    """Report a search as problematic to the backend.

    Sends the query, the requested number of results, the user's free-text
    description and a ``"flag": True`` marker as a JSON POST to the URL stored
    in the "address" environment variable. The language is only included when
    it is not the "detect_language" pseudo-choice.

    Flagging is best-effort: any failure is logged to the console (with a
    traceback) and otherwise ignored.

    Args:
        query: The search query being flagged.
        language: Language selected for the search.
        num_results: Number of results the user requested.
        issue_description: Free-text description of the issue.

    Returns:
        Always the empty string.
    """
    post_data = {
        "query": query,
        "k": num_results,
        "flag": True,
        "description": issue_description,
    }
    if language != "detect_language":
        post_data["lang"] = language

    try:
        output = requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=120,
        )
        # The decoded body is not used; parsing only checks that the server
        # answered with valid JSON.
        json.loads(output.text)
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate; the traceback keeps the cause visible.
        print("Error flagging")
        traceback.print_exc()
    return ""
def format_result(result, highlight_terms, exact_search, datasets_filter=None):
    """Render a single search hit as an HTML paragraph.

    Args:
        result: ``(text, url, docid)`` tuple; ``docid`` has the form
            ``<org>/<dataset>/<document id>`` and ``url`` may be None.
        highlight_terms: For exact search, the query string to embolden
            (first occurrence only); otherwise a collection of tokens, each
            whitespace-separated token of ``text`` found in it is emboldened.
        exact_search: Selects which of the two highlighting modes is used.
        datasets_filter: Optional iterable of dataset names; when given, hits
            from datasets not in it are suppressed.

    Returns:
        The HTML for the hit wrapped in ``<p>``, or ``""`` when the hit is
        excluded by ``datasets_filter``.
    """
    text, url, docid = result
    if datasets_filter is not None:
        dataset = docid.split("/")[1]
        if dataset not in set(datasets_filter):
            return ""

    if exact_search:
        # str.find returns -1 when the query is absent; slicing with -1 would
        # drop the last character and bold the wrong span, so in that case
        # (or for an empty query) the text is shown without highlighting.
        query_start = text.find(highlight_terms) if highlight_terms else -1
        if query_start == -1:
            tokens_html = text
        else:
            query_end = query_start + len(highlight_terms)
            tokens_html = text[0:query_start]
            tokens_html += "<b>{}</b>".format(text[query_start:query_end])
            tokens_html += text[query_end:]
    else:
        terms = highlight_terms or ()  # tolerate a missing term list
        tokens_html = " ".join(
            "<b>{}</b>".format(token) if token in terms else token
            for token in text.split()
        )
    tokens_html = process_pii(tokens_html)

    url_html = (
        """
<span style='font-size:12px; font-family: Arial; color:Silver; text-align: left;'>
<a style='text-decoration:none; color:Silver;'
onmouseover="style='text-decoration:underline; color:Silver;'"
onmouseout="style='text-decoration:none; color:Silver;'"
href='{url}'
target="_blank">
{url}
</a>
</span><br>
""".format(
            url=url
        )
        if url is not None
        else ""
    )
    docid_html = get_docid_html(docid)
    # Placeholder: only ends up inside an HTML comment in the template below.
    language = "FIXME"
    result_html = """{}
<span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</span><br>
<!-- <span style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</span><br> -->
<span style='font-family: Arial;'>{}</span><br>
<br>
""".format(
        url_html, docid_html, language, tokens_html
    )
    return "<p>" + result_html + "</p>"
def format_result_page(
    results, highlight_terms, num_results, exact_search, datasets_filter=None
):
    """Render every hit in ``results`` and keep the non-empty fragments.

    Args:
        results: Iterable of hits, each passed unchanged to ``format_result``.
        highlight_terms: Forwarded to ``format_result``.
        num_results: Accepted for interface compatibility; not used here.
        exact_search: Forwarded to ``format_result``.
        datasets_filter: Forwarded to ``format_result``.

    Returns:
        List of HTML strings, one per hit for which ``format_result`` returned
        something other than ``""``, in input order.
    """
    rendered = (
        format_result(hit, highlight_terms, exact_search, datasets_filter)
        for hit in results
    )
    return [fragment for fragment in rendered if fragment != ""]
def extract_results_from_payload(query, language, payload, exact_search):
    """Unpack a backend response into the pieces the UI works with.

    Args:
        query: The search query; used as the highlight string for exact search.
        language: Not used by this function.
        payload: Decoded response dict. Reads ``payload["results"]`` and,
            depending on the mode, ``payload["num_results"]`` (exact search)
            or ``payload["highlight_terms"]`` (otherwise).
        exact_search: Whether the response belongs to an exact search.

    Returns:
        Tuple ``(processed_results, highlight_terms, num_results, datasets)``:
        a list of ``(text, url, docid)`` tuples (``url`` is None when the hit
        has no usable ``meta["url"]``), the highlight terms, the match count
        (None unless exact search) and a list of the distinct dataset names
        taken from the middle component of each docid.
    """
    hits = payload["results"]

    if exact_search:
        highlight_terms, num_results = query, payload["num_results"]
    else:
        highlight_terms, num_results = payload["highlight_terms"], None

    processed_results = []
    dataset_names = set()
    for hit in hits:
        text = hit["text"]
        url = None
        if "meta" in hit:
            meta = hit["meta"]
            if meta is not None and "url" in meta:
                url = meta["url"]
        docid = hit["docid"]
        # Unpacking (rather than indexing) keeps the requirement that the
        # docid has exactly three "/"-separated components.
        _org, dataset_name, _doc = docid.split("/")
        dataset_names.add(dataset_name)
        processed_results.append((text, url, docid))

    return processed_results, highlight_terms, num_results, list(dataset_names)
def process_error(error_type, payload=None):
    """Build the user-facing HTML message for a backend error.

    Args:
        error_type: Error identifier as returned by the backend. Only
            ``"unsupported_lang"`` is handled.
        payload: Optional decoded backend response. When it contains
            ``payload["err"]["meta"]["detected_lang"]`` that language is
            named in the message; otherwise "unknown" is shown. (Previously
            the function read an undefined ``payload`` name and raised
            NameError.)

    Returns:
        An HTML string for ``"unsupported_lang"``; None for any other
        error type.
    """
    if error_type != "unsupported_lang":
        return None

    try:
        detected_lang = payload["err"]["meta"]["detected_lang"]
    except (KeyError, TypeError):
        # No payload, or the payload does not carry the detected language.
        detected_lang = "unknown"
    return f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Detected language <b>{detected_lang}</b> is not supported.<br>
Please choose a language from the dropdown or type another query.
</p><br><hr><br>"""
def extract_error_from_payload(payload):
    """Return the error type carried by a backend response, if any.

    Args:
        payload: Decoded response; an error is signalled by an ``"err"`` key.

    Returns:
        ``payload["err"]["type"]`` when the key ``"err"`` is present,
        otherwise None.
    """
    return payload["err"]["type"] if "err" in payload else None
def request_payload(query, language, exact_search, num_results=10, received_results=0):
    """POST a search request to the backend and return the decoded JSON reply.

    Exact searches are sent to a fixed address; all other searches go to the
    URL stored in the "address" environment variable. The language is only
    included in the request when it is not "detect_language".

    Args:
        query: Search query.
        language: Selected language.
        exact_search: Whether to use the exact-search endpoint.
        num_results: Sent as ``"k"`` in the request body.
        received_results: Sent as ``"received_results"`` in the request body.

    Returns:
        The response body parsed with ``json.loads``.
    """
    request_body = {
        "query": query,
        "k": num_results,
        "received_results": received_results,
    }
    if language != "detect_language":
        request_body["lang"] = language

    if exact_search:
        address = "http://34.105.160.81:8080"
    else:
        address = os.environ.get("address")

    response = requests.post(
        address,
        headers={"Content-type": "application/json"},
        data=json.dumps(request_body),
        timeout=60,
    )
    return json.loads(response.text)
# Page heading, rendered through gr.Markdown. The magnifying-glass emoji were
# stored as mis-decoded UTF-8 bytes and are restored here.
title = (
    """<p style="text-align: center; font-size:28px"> 🌸 🔎 ROOTS search tool 🔍 🌸 </p>"""
)
# Introductory Markdown text shown below the heading.
description = """
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows
you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in
ROOTS. You can read more about the details of the tool design
[here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more
information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""
if __name__ == "__main__":
    # Builds the Gradio UI, wires the callbacks and launches the app.
    # NOTE(review): the nesting of components inside rows/columns was
    # reconstructed from statement order — confirm against the intended layout
    # (e.g. that the "Max Results" slider shares the row with the language
    # dropdown).
    demo = gr.Blocks(
        css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
    )

    with demo:
        # State objects handed to / returned from the callbacks below.
        processed_results_state = gr.State([])
        highlight_terms_state = gr.State([])
        num_results_state = gr.State(0)
        exact_search_state = gr.State(False)
        lang_state = gr.State("")
        max_page_size_state = gr.State(100)
        received_results_state = gr.State(0)

        with gr.Row():
            gr.Markdown(value=title)
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Put your query in double quotes for exact search.",
                label="Query",
            )
        with gr.Row():
            lang = gr.Dropdown(
                choices=[
                    "ar",
                    "ca",
                    "code",
                    "en",
                    "es",
                    "eu",
                    "fr",
                    "id",
                    "indic",
                    "nigercongo",
                    "pt",
                    "vi",
                    "zh",
                    "detect_language",
                ],
                value="en",
                label="Language",
            )
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row(visible=False) as datasets_filter:
            available_datasets = gr.Dropdown(
                type="value",
                choices=[],
                value=[],
                label="Datasets Filter",
                multiselect=True,
            )
        with gr.Row() as results_row:
            header_html = gr.HTML(label="Header")
            results_html = gr.HTML(label="Results")
        with gr.Row(visible=False) as pagination:
            next_page_btn = gr.Button("Next Page")
        with gr.Column(visible=False) as flagging_form:
            flag_txt = gr.Textbox(
                lines=1,
                placeholder="Type here...",
                label="""If you choose to flag your search, we will save the query, language and the number of results
you requested. Please consider adding relevant additional context below:""",
            )
            flag_btn = gr.Button("Flag Results")
            flag_btn.click(flag, inputs=[query, lang, k, flag_txt], outputs=[flag_txt])

        def render_html(fragments):
            """Collapse HTML fragments into the single string gr.HTML takes."""
            if isinstance(fragments, str):
                return fragments
            return "".join(fragments)

        def run_query(query, lang, k, dropdown_input, max_page_size, received_results):
            """Run one search request and prepare everything the UI needs.

            A query wrapped in double quotes is an exact search: the quotes
            are stripped and the page size replaces ``k``. Otherwise runs of
            whitespace in the query are collapsed.

            Always returns the same 6-tuple
            ``(processed_results, highlight_terms, num_results, exact_search,
            results_html_fragments, datasets)``. For an empty query the result
            fields are empty; for a backend error the fragments hold a single
            error message. (Previously these two paths returned None or a bare
            string, which the callers could not unpack.)
            """
            query = query.strip()
            exact_search = False
            if query.startswith('"') and query.endswith('"') and len(query) >= 2:
                exact_search = True
                query = query[1:-1]
                k = max_page_size
            else:
                query = " ".join(query.split())
            if not query:
                return [], [], None, exact_search, [], []

            print("submitting", query, lang, k)
            payload = request_payload(query, lang, exact_search, k, received_results)
            err = extract_error_from_payload(payload)
            if err is not None:
                # process_error may not know this error type; fall back to a
                # generic message so the user still gets feedback.
                error_html = process_error(err) or """
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
The search backend reported an error: <b>{}</b>.
</p><br><hr><br>""".format(
                    err
                )
                return [], [], None, exact_search, [error_html], []

            (
                processed_results,
                highlight_terms,
                num_results,
                ds,
            ) = extract_results_from_payload(
                query,
                lang,
                payload,
                exact_search,
            )

            # NOTE(review): this header is assembled but not returned or
            # rendered anywhere; the detected-language value is still a
            # placeholder. Left as is until the header is wired up.
            header_html = ""
            if lang == "detect_language" and not exact_search:
                header_html += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
Detected language: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                    "FIX ME!"
                )
            if len(processed_results) == 0:
                header_html += """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
No results found.</div>"""
            elif num_results is not None:
                header_html += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
Total number of matches: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                    num_results
                )

            results_html_new = format_result_page(
                processed_results, highlight_terms, num_results, exact_search
            )
            return (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                results_html_new,
                ds,
            )

        def submit(query, lang, k, dropdown_input, max_page_size):
            """Callback for a fresh search (Enter in the textbox or Submit)."""
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                results_html_new,
                datasets,
            ) = run_query(query, lang, k, dropdown_input, max_page_size, 0)
            # num_results is None unless the backend reported a total.
            has_more_results = bool(
                exact_search
                and num_results is not None
                and num_results > max_page_size
            )
            return {
                processed_results_state: processed_results,
                highlight_terms_state: highlight_terms,
                num_results_state: num_results,
                exact_search_state: exact_search,
                results_html: render_html(results_html_new),
                flagging_form: gr.update(visible=True),
                datasets_filter: gr.update(visible=True),
                available_datasets: gr.Dropdown.update(
                    choices=datasets, value=datasets
                ),
                pagination: gr.update(visible=has_more_results),
                received_results_state: len(processed_results),
            }

        def next_page(
            query,
            lang,
            k,
            dropdown_input,
            max_page_size,
            received_results,
            processed_results,
        ):
            """Callback for "Next Page": fetch results after those already received."""
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                results_html_new,
                datasets,
            ) = run_query(
                query, lang, k, dropdown_input, max_page_size, received_results
            )
            num_processed_results = len(processed_results)
            has_more_results = bool(
                exact_search
                and num_results is not None
                and num_results > max_page_size
            )
            print("num_processed_results", num_processed_results)
            print("has_more_results", has_more_results)
            print("received_results", received_results)
            return {
                processed_results_state: processed_results,
                highlight_terms_state: highlight_terms,
                num_results_state: num_results,
                exact_search_state: exact_search,
                results_html: render_html(results_html_new),
                flagging_form: gr.update(visible=True),
                datasets_filter: gr.update(visible=True),
                available_datasets: gr.Dropdown.update(
                    choices=datasets, value=datasets
                ),
                pagination: gr.update(
                    visible=num_processed_results >= max_page_size and has_more_results
                ),
                received_results_state: received_results + num_processed_results,
            }

        def filter_datasets(
            lang,
            processed_results,
            highlight_terms,
            num_results,
            exact_search,
            datasets_filter,
        ):
            """Re-render the stored results restricted to the selected datasets."""
            if not processed_results:
                # Nothing to filter: leave the current content (for example an
                # error message) untouched instead of blanking it.
                return {results_html: gr.update()}
            results_html_new = format_result_page(
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                datasets_filter,
            )
            return {results_html: render_html(results_html_new)}

        query.submit(
            fn=submit,
            inputs=[query, lang, k, available_datasets, max_page_size_state],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                results_html,
                flagging_form,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
            ],
        )
        submit_btn.click(
            submit,
            inputs=[query, lang, k, available_datasets, max_page_size_state],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                results_html,
                flagging_form,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
            ],
        )
        next_page_btn.click(
            next_page,
            inputs=[
                query,
                lang,
                k,
                available_datasets,
                max_page_size_state,
                received_results_state,
                processed_results_state,
            ],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                results_html,
                flagging_form,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
            ],
        )
        available_datasets.change(
            filter_datasets,
            inputs=[
                lang,
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                available_datasets,
            ],
            outputs=[results_html],
        )

    demo.launch(enable_queue=True, debug=True)