import json import os import traceback from typing import List, Tuple import gradio as gr import requests from huggingface_hub import HfApi hf_api = HfApi() roots_datasets = { dset.id.split("/")[-1]: dset for dset in hf_api.list_datasets( author="bigscience-data", use_auth_token=os.environ.get("bigscience_data_token") ) } def get_docid_html(docid): data_org, dataset, docid = docid.split("/") metadata = roots_datasets[dataset] if metadata.private: docid_html = """ 🔒{dataset} /{docid}""".format( dataset=dataset, docid=docid ) else: docid_html = """ {dataset} /{docid}""".format( metadata=metadata.tags[0].split(":")[-1], dataset=dataset, docid=docid ) return docid_html PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"} PII_PREFIX = "PI:" def process_pii(text): for tag in PII_TAGS: text = text.replace( PII_PREFIX + tag, """REDACTED {}""".format( tag ), ) return text def flag(query, language, num_results, issue_description): try: post_data = { "query": query, "k": num_results, "flag": True, "description": issue_description, } if language != "detect_language": post_data["lang"] = language output = requests.post( os.environ.get("address"), headers={"Content-type": "application/json"}, data=json.dumps(post_data), timeout=120, ) results = json.loads(output.text) except: print("Error flagging") return "" def format_result(result, highlight_terms, exact_search, datasets_filter=None): text, url, docid = result if datasets_filter is not None: datasets_filter = set(datasets_filter) dataset = docid.split("/")[1] if not dataset in datasets_filter: return "" if exact_search: query_start = text.find(highlight_terms) query_end = query_start + len(highlight_terms) tokens_html = text[0:query_start] tokens_html += "{}".format(text[query_start:query_end]) tokens_html += text[query_end:] else: tokens = text.split() tokens_html = [] for token in tokens: if token in highlight_terms: tokens_html.append("{}".format(token)) else: tokens_html.append(token) tokens_html = " ".join(tokens_html) tokens_html = process_pii(tokens_html) url_html = ( """ {url}
""".format( url=url ) if url is not None else "" ) docid_html = get_docid_html(docid) language = "FIXME" result_html = """{} Document ID: {}
{}

""".format( url_html, docid_html, language, tokens_html ) return "

" + result_html + "

" def format_result_page( language, results, highlight_terms, num_results, exact_search, datasets_filter=None ) -> gr.HTML: filtered_num_results = 0 header_html = "" if language == "detect_language" and not exact_search: header_html += """
Detected language: {}
""".format( list(results.keys())[0] ) results_html = "" for lang, results_for_lang in results.items(): if len(results_for_lang) == 0: if exact_search: results_html += """
No results found.
""" else: results_html += """
No results for language: {}
""".format( lang ) continue results_for_lang_html = "" for result in results_for_lang: result_html = format_result( result, highlight_terms, exact_search, datasets_filter ) if result_html != "": filtered_num_results += 1 results_for_lang_html += result_html if language == "all" and not exact_search: results_for_lang_html = f"""
Results for language: {lang} {results_for_lang_html}
""" results_html += results_for_lang_html if num_results is not None: header_html += """
Total number of matches: {}
""".format( num_results ) return header_html + results_html def extract_results_from_payload(query, language, payload, exact_search): results = payload["results"] processed_results = dict() datasets = set() highlight_terms = None num_results = None if exact_search: highlight_terms = query num_results = payload["num_results"] results = {"dummy": results} else: highlight_terms = payload["highlight_terms"] for lang, results_for_lang in results.items(): processed_results[lang] = list() for result in results_for_lang: text = result["text"] url = ( result["meta"]["url"] if "meta" in result and result["meta"] is not None and "url" in result["meta"] else None ) docid = result["docid"] _, dataset, _ = docid.split("/") datasets.add(dataset) processed_results[lang].append((text, url, docid)) return processed_results, highlight_terms, num_results, list(datasets) def process_error(error_type): if error_type == "unsupported_lang": detected_lang = payload["err"]["meta"]["detected_lang"] return f"""

Detected language {detected_lang} is not supported.
Please choose a language from the dropdown or type another query.




""" def extract_error_from_payload(payload): if "err" in payload: return payload["err"]["type"] return None def request_payload(query, language, exact_search, num_results=10, received_results=0): post_data = {"query": query, "k": num_results, "received_results": received_results} if language != "detect_language": post_data["lang"] = language address = "http://34.105.160.81:8080" if exact_search else os.environ.get("address") output = requests.post( address, headers={"Content-type": "application/json"}, data=json.dumps(post_data), timeout=60, ) payload = json.loads(output.text) return payload title = ( """

🌸 🔎 ROOTS search tool 🔍 🌸

""" ) description = """ The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in ROOTS. You can read more about the details of the tool design [here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99).""" if __name__ == "__main__": demo = gr.Blocks( css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }" ) with demo: processed_results_state = gr.State([]) highlight_terms_state = gr.State([]) num_results_state = gr.State(0) exact_search_state = gr.State(False) lang_state = gr.State("") max_page_size_state = gr.State(100) received_results_state = gr.State(0) with gr.Row(): gr.Markdown(value=title) with gr.Row(): gr.Markdown(value=description) with gr.Row(): query = gr.Textbox( lines=1, max_lines=1, placeholder="Put your query in double quotes for exact search.", label="Query", ) with gr.Row(): lang = gr.Dropdown( choices=[ "ar", "ca", "code", "en", "es", "eu", "fr", "id", "indic", "nigercongo", "pt", "vi", "zh", "detect_language", "all", ], value="en", label="Language", ) k = gr.Slider(1, 100, value=10, step=1, label="Max Results") with gr.Row(): submit_btn = gr.Button("Submit") with gr.Row(visible=False) as datasets_filter: available_datasets = gr.Dropdown( type="value", choices=[], value=[], label="Datasets Filter", multiselect=True, ) with gr.Row(): results = gr.HTML(label="Results") with gr.Row(visible=False) as pagination: next_page_btn = gr.Button("Next Page") with gr.Column(visible=False) as flagging_form: flag_txt = gr.Textbox( lines=1, placeholder="Type here...", label="""If you choose to flag your search, we will save the query, language and the number of results you requested. Please consider adding relevant additional context below:""", ) flag_btn = gr.Button("Flag Results") flag_btn.click(flag, inputs=[query, lang, k, flag_txt], outputs=[flag_txt]) def run_query(query, lang, k, dropdown_input, max_page_size, received_results): query = query.strip() exact_search = False if query.startswith('"') and query.endswith('"') and len(query) >= 2: exact_search = True query = query[1:-1] k = max_page_size else: query = " ".join(query.split()) if query == "" or query is None: return None print("submitting", query, lang, k) payload = request_payload(query, lang, exact_search, k, received_results) err = extract_error_from_payload(payload) if err is not None: return process_error(err) ( processed_results, highlight_terms, num_results, ds, ) = extract_results_from_payload( query, lang, payload, exact_search, ) results_html = format_result_page( lang, processed_results, highlight_terms, num_results, exact_search ) return ( processed_results, highlight_terms, num_results, exact_search, results_html, ds, ) def submit(query, lang, k, dropdown_input, max_page_size): ( processed_results, highlight_terms, num_results, exact_search, results_html, datasets, ) = run_query(query, lang, k, dropdown_input, max_page_size, 0) has_more_results = exact_search and (num_results > max_page_size) return { processed_results_state: processed_results, highlight_terms_state: highlight_terms, num_results_state: num_results, exact_search_state: exact_search, results: results_html, flagging_form: gr.update(visible=True), datasets_filter: gr.update(visible=True), available_datasets: gr.Dropdown.update( choices=datasets, value=datasets ), pagination: gr.update(visible=has_more_results), received_results_state: len(next(iter(processed_results.values()))), } def next_page( query, lang, k, dropdown_input, max_page_size, received_results, processed_results, ): ( processed_results, highlight_terms, num_results, exact_search, results_html, datasets, ) = run_query( query, lang, k, dropdown_input, max_page_size, received_results ) num_processed_results = len(next(iter(processed_results.values()))) has_more_results = exact_search and (num_results > max_page_size) print("num_processed_results", num_processed_results) print("has_more_results", has_more_results) print("current page", len(list(processed_results.values())[0])) return { processed_results_state: processed_results, highlight_terms_state: highlight_terms, num_results_state: num_results, exact_search_state: exact_search, results: results_html, flagging_form: gr.update(visible=True), datasets_filter: gr.update(visible=True), available_datasets: gr.Dropdown.update( choices=datasets, value=datasets ), pagination: gr.update( visible=num_processed_results >= max_page_size and has_more_results ), received_results_state: received_results + num_processed_results, } def filter_datasets( lang, processed_results, highlight_terms, num_results, exact_search, datasets_filter, ): results_html = format_result_page( lang, processed_results, highlight_terms, num_results, exact_search, datasets_filter, ) return {results: results_html} query.submit( fn=submit, inputs=[query, lang, k, available_datasets, max_page_size_state], outputs=[ processed_results_state, highlight_terms_state, num_results_state, exact_search_state, results, flagging_form, datasets_filter, available_datasets, pagination, received_results_state, ], ) submit_btn.click( submit, inputs=[query, lang, k, available_datasets, max_page_size_state], outputs=[ processed_results_state, highlight_terms_state, num_results_state, exact_search_state, results, flagging_form, datasets_filter, available_datasets, pagination, received_results_state, ], ) next_page_btn.click( next_page, inputs=[ query, lang, k, available_datasets, max_page_size_state, received_results_state, processed_results_state, ], outputs=[ processed_results_state, highlight_terms_state, num_results_state, exact_search_state, results, flagging_form, datasets_filter, available_datasets, pagination, received_results_state, ], ) available_datasets.change( filter_datasets, inputs=[ lang, processed_results_state, highlight_terms_state, num_results_state, exact_search_state, available_datasets, ], outputs=[results], ) demo.launch(enable_queue=True, debug=True)