"""Gradio demo: count documents containing given phrases in WIMBD indices.

The script builds two Elasticsearch clients from environment variables,
exposes a single Gradio interface (one textbox plus one checkbox per dataset),
and launches it when the module is executed.
"""

import os
from functools import lru_cache

import gradio as gr
import plotly.graph_objects as go
from wimbd.es import es_init, count_documents_containing_phrases

# Client for every index except the Dolma ones.
es = es_init(None, os.getenv("lm_datasets_cloud_id"), os.getenv("lm_datasets_api_key"))
# Dolma indices are queried through a separate client / credential pair.
es_dolma = es_init(None, os.getenv("dolma_cloud_id"), os.getenv("dolma_api_key"))

# NOTE(review): not referenced anywhere in this file and out of sync with
# `dataset_es_map` ("Dolma" is not a key there); kept in case it is imported
# elsewhere.
datasets = ["OpenWebText", "C4", "OSCAR", "The Pile", "LAION-2B-en", "Dolma"]

# Display name -> Elasticsearch index name/pattern. The insertion order of
# this dict defines the order of the checkboxes in the UI, and
# `process_input` relies on that same order to pair checkbox values with
# datasets.
dataset_es_map = {
    "OSCAR": "re_oscar",
    "LAION-2B-en": "re_laion2b-en-*",
    "LAION-5B": "*laion2b*",
    "OpenWebText": "openwebtext",
    "The Pile": "re_pile",
    "C4": "c4",
    "Dolma v1.5": "docs_v1.5_2023-11-02",
    "Dolma v1.7": "docs_v1.7_2024-06-04",
    "Tulu v2": "tulu-v2-sft-mixture",
}

default_checked = ["C4", "The Pile", "Dolma v1.7"]  # Datasets to be checked by default


@lru_cache()
def get_counts(index_name, phrase, es):
    """Return the count reported by WIMBD for ``phrase`` in ``index_name``.

    Results are memoised per ``(index_name, phrase, es)`` for the lifetime of
    the process, so repeated queries do not hit Elasticsearch again. All three
    arguments must therefore be hashable.

    Args:
        index_name: Elasticsearch index name or pattern to search.
        phrase: The phrase to look for.
        es: The Elasticsearch client to run the query with.

    Returns:
        Whatever ``count_documents_containing_phrases`` returns (presumably an
        integer document count -- confirm against the wimbd documentation).
    """
    return count_documents_containing_phrases(index_name, phrase, es=es)


def process_input(phrases, *dataset_choices):
    """Count every entered phrase in every selected dataset.

    Args:
        phrases: Text with one phrase per line. Blank lines are skipped,
            surrounding whitespace is stripped, and repeated phrases are
            queried and reported only once. ``None`` is treated as empty.
        *dataset_choices: One truthy/falsy flag per entry of
            ``dataset_es_map``, in the same order as the dict.

    Returns:
        A ``(table_data, fig)`` tuple: ``table_data`` is a list of
        ``[dataset, phrase, count-as-string]`` rows and ``fig`` is a grouped
        plotly bar chart with one trace per phrase.
    """
    # dict.fromkeys de-duplicates while keeping the order the user typed, so
    # a phrase entered twice does not yield duplicate rows / overlapping bars.
    stripped_lines = (line.strip() for line in (phrases or "").split("\n"))
    unique_phrases = list(dict.fromkeys(line for line in stripped_lines if line))

    results = []
    for (dataset_name, index_name), is_selected in zip(
        dataset_es_map.items(), dataset_choices
    ):
        if not is_selected:
            continue
        # Dolma indices are served by the dedicated Dolma client.
        client = es_dolma if "dolma" in dataset_name.lower() else es
        for phrase in unique_phrases:
            count = get_counts(index_name, phrase, es=client)
            results.append((dataset_name, phrase, count))

    # Rows for the Dataframe component; counts are rendered as strings.
    table_data = [[dataset, phrase, str(count)] for dataset, phrase, count in results]

    # Group the results per phrase in a single pass. A dict (rather than a
    # set) keeps first-seen order, so the trace/legend order is stable
    # between runs instead of depending on string hashing.
    series_by_phrase = {}
    for dataset_name, phrase, count in results:
        names, counts = series_by_phrase.setdefault(phrase, ([], []))
        names.append(dataset_name)
        counts.append(count)

    fig = go.Figure()
    for phrase, (names, counts) in series_by_phrase.items():
        fig.add_trace(go.Bar(x=names, y=counts, name=phrase))

    fig.update_layout(
        title="Document Counts by Dataset and Phrase",
        xaxis_title="Dataset",
        yaxis_title="Count",
        barmode="group",
    )

    return table_data, fig


# Shown below the interface via the `article` argument. Written as adjacent
# literals so the trailing space before the embedded newline stays visible.
citation_text = (
    "If you find this tool useful, please kindly cite \n"
    "our paper: ```bibtex @inproceedings{elazar2023s, "
    "title={What's In My Big Data?}, "
    "author={Elazar, Yanai and Bhagia, Akshita and Magnusson, Ian Helgi and "
    "Ravichander, Abhilasha and Schwenk, Dustin and Suhr, Alane and "
    "Walsh, Evan Pete and Groeneveld, Dirk and Soldaini, Luca and "
    "Singh, Sameer and Hajishirzi, Hanna and Smith, Noah A. and Dodge, Jesse}, "
    "booktitle={The Twelfth International Conference on Learning Representations}, "
    "year={2024} }```"
)

# Markdown description rendered under the interface title.
_APP_DESCRIPTION = (
    "This app connects to the WIMBD Elasticsearch instance and counts the number of "
    "documents containing a given string in the various indexed datasets.\\ "
    "The app uses the wimbd pypi package, which can be installed by simply running "
    "`pip install wimbd`.\\ "
    "Access to the indices require an API key, due to the sensitive nature of the data, "
    "but can be accessed by filling up the following "
    "[form](https://forms.gle/Mk9uwJibR9H4hh9Y9).\\ "
    "This app was created by [Yanai Elazar](https://yanaiela.github.io/), and for bugs, "
    "improvements, or feature requests, please open an issue on the "
    "[GitHub repository](https://github.com/allenai/wimbd), or send me an email. \n"
    "The returned counts are the number of documents that contain each string per dataset."
)


def custom_layout(input_components, output_components, citation):
    """Return the components in display order as a flat list.

    The order is: the textbox, every checkbox, the dataframe, the plot and
    finally the citation component.

    NOTE(review): this helper is not wired into the interface below. It used
    to be passed as ``theme=``, which expects a Gradio theme rather than a
    function; it is kept only so existing references to the name keep working.
    """
    return [
        input_components[0],  # Textbox
        *input_components[1:],  # Checkboxes
        output_components[0],  # Dataframe
        output_components[1],  # Plot
        citation,  # Citation Markdown
    ]


iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(label="Enter phrases (one per line)", lines=5),
        # One checkbox per dataset, in `dataset_es_map` order (process_input
        # depends on this ordering).
        *[
            gr.Checkbox(label=dataset, value=(dataset in default_checked))
            for dataset in dataset_es_map
        ],
    ],
    outputs=[
        gr.Dataframe(headers=["Dataset", "Phrase", "Count"], label="Counts Table"),
        gr.Plot(label="Results Chart"),
    ],
    title="What's In My Big Data? String Counts Demo",
    description=_APP_DESCRIPTION,
    article=citation_text,  # This adds the citation at the bottom
    # `theme=custom_layout` was dropped on purpose: `theme` takes a Theme
    # object or theme name, so a plain function never customised the layout.
)

iface.launch()