lvwerra's picture
lvwerra HF staff
Update app.py
e146ae1
raw
history blame
4.37 kB
import html
import http.client as http_client
import json
import logging
import os
import re
import string

import gradio as gr
import requests
def mark_tokens_bold(string, tokens):
    """Return ``string`` as escaped HTML with every occurrence of ``tokens`` highlighted.

    Each match is wrapped in ``<span style='color: red;'><b>...</b></span>``.
    All tokens are matched in a single left-to-right pass. A highlight inserted
    for one token can therefore never be re-matched by another token (e.g. a
    token ``"b"`` or ``"span"`` cannot corrupt the inserted markup). When one
    token is a prefix of another, the longer token wins. Text is HTML-escaped,
    so source code such as ``#include <stdio.h>`` renders literally instead of
    being parsed as markup.

    Args:
        string: Raw (unescaped) text to highlight. The name shadows the
            ``string`` module inside this function; it is kept for caller
            compatibility.
        tokens: Iterable of literal substrings to highlight (case-sensitive).
            Empty tokens and ``None`` are ignored.

    Returns:
        An HTML fragment safe to embed as element content.
    """
    # Deduplicate, drop empty tokens (an empty pattern would match at every
    # position), and order longest-first so the alternation prefers
    # "foobar" over its prefix "foo".
    unique_tokens = sorted({t for t in (tokens or ()) if t}, key=lambda t: (-len(t), t))
    if not unique_tokens:
        return html.escape(string, quote=False)
    pattern = re.compile("|".join(re.escape(t) for t in unique_tokens))

    pieces = []
    last_end = 0
    for match in pattern.finditer(string):
        pieces.append(html.escape(string[last_end:match.start()], quote=False))
        # Build the highlight by concatenation rather than as an re.sub
        # replacement string, so backslashes in tokens are not treated as
        # escape sequences or group references.
        pieces.append(
            "<span style='color: red;'><b>"
            + html.escape(match.group(0), quote=False)
            + "</b></span>"
        )
        last_end = match.end()
    pieces.append(html.escape(string[last_end:], quote=False))
    return "".join(pieces)
def process_results(results, highlight_terms):
    """Render search hits as an HTML fragment for the results panel.

    Args:
        results: Sequence of hit dicts. Each hit is read for ``"text"``,
            ``"repo_license"`` (an iterable of strings joined with `` | ``),
            ``"repo_name"``, ``"repo_path"`` and an optional ``"meta"`` dict
            whose ``"url"`` is rendered as a link.
        highlight_terms: Tokens passed to ``mark_tokens_bold`` to highlight
            inside each hit's text.

    Returns:
        An HTML string ending in ``<hr>``. If ``results`` is empty or ``None``,
        returns a "No results retrieved." notice instead.
    """
    if not results:
        return """<br><p style='font-family: Arial; color:Silver; text-align: center;'>
No results retrieved.</p><br><hr>"""

    blocks = []
    for result in results:
        text_html = mark_tokens_bold(result["text"], highlight_terms)

        meta = result.get("meta")
        if meta is not None and "url" in meta:
            # The URL lands inside a single-quoted href attribute, so escape
            # quotes as well (html.escape's default quote=True).
            url = html.escape(str(meta["url"]))
            meta_html = """
<p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
<a href='{}' target='_blank'>{}</a></p>""".format(url, url)
        else:
            meta_html = ""

        # Repository metadata comes from the backend and is untrusted; escape
        # it before embedding in markup.
        licenses = html.escape(" | ".join(result["repo_license"]), quote=False)
        repo_name = html.escape(str(result["repo_name"]), quote=False)
        repo_path = html.escape(str(result["repo_path"]), quote=False)

        blocks.append(
            """{}
<p style='font-size:16px; font-family: Arial; text-align: left;'>Repository name: <span style='color: #20233fff;'>{}</span></p>
<p style='font-size:16px; font-family: Arial; text-align: left;'>Repository path: <span style='color: #20233fff;'>{}</span></p>
<p style='font-size:16px; font-family: Arial; text-align: left;'>Repository licenses: <span style='color: #20233fff;'>{}</span></p>
<pre style='height: 600px; overflow: scroll;'><code>{}</code></pre>
<br>
""".format(meta_html, repo_name, repo_path, licenses, text_html)
        )
    # Join once instead of repeated += concatenation.
    return "".join(blocks) + "<hr>"
def scisearch(query, language, num_results=10):
    """Query the search backend and return the hits rendered as HTML.

    Args:
        query: Free-text query. Runs of whitespace are collapsed. A ``None``
            or blank query returns ``""`` without contacting the backend.
        language: Currently unused; accepted for interface compatibility.
        num_results: Maximum number of hits to request (sent as ``"k"``).

    Returns:
        HTML produced by ``process_results``, or ``""`` for an empty query.

    Raises:
        RuntimeError: If the ``address`` environment variable is not set.
        requests.HTTPError: If the backend responds with an error status.
    """
    # Check for None before calling .split(); otherwise None raises AttributeError.
    if query is None:
        return ""
    query = " ".join(query.split())
    if not query:
        return ""

    address = os.environ.get("address")
    if not address:
        raise RuntimeError("The 'address' environment variable must point to the search backend.")

    # json= serializes the body and sets Content-Type: application/json.
    response = requests.post(address, json={"query": query, "k": num_results}, timeout=60)
    # Report backend errors directly instead of failing later on JSON parsing
    # or a missing key.
    response.raise_for_status()
    payload = response.json()
    return process_results(payload["results"], payload["highlight_terms"])
# Markdown rendered at the top of the page: title, corpus background and links.
description = """# <p style="text-align: center;"> 🌸 🔎 ROOTS search tool 🔍 🌸 </p>
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows
you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in
ROOTS. You can read more about the details of the tool design
[here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more
information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""
if __name__ == "__main__":
    # Build the search UI: description, query box, result-count slider,
    # submit button and an HTML panel for the rendered hits.
    demo = gr.Blocks(
        css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
    )
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(lines=1, max_lines=1, placeholder="Type your query here...", label="Query")
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row():
            results = gr.HTML(label="Results")

        def submit(query_text, max_results, lang="en"):
            """Gradio callback: run the search and return HTML for the results panel.

            Only one output component (``results``) is wired up, so every path
            returns exactly one value. A tuple would not match the outputs list.
            """
            if query_text is None or not query_text.strip():
                return ""
            # NOTE(review): slider values may arrive as floats depending on the
            # Gradio version; send the backend an integer k.
            return scisearch(query_text.strip(), lang, int(max_results))

        query.submit(fn=submit, inputs=[query, k], outputs=[results])
        submit_btn.click(submit, inputs=[query, k], outputs=[results])

    # NOTE(review): `enable_queue` is deprecated in newer Gradio releases in
    # favour of `demo.queue()`; confirm against the pinned Gradio version.
    demo.launch(enable_queue=True, debug=True)