|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import ast |
|
import gradio as gr |
|
from os.path import isdir |
|
from data_measurements.dataset_statistics import DatasetStatisticsCacheClass as dmt_cls |
|
import utils |
|
from utils import dataset_utils |
|
from utils import gradio_utils as gr_utils |
|
import widgets |
|
import app as ap |
|
from app import load_or_prepare_widgets |
|
|
|
|
|
# Module-wide logger, obtained from the project's logging helper for this file.
logs = utils.prepare_logging(__file__)

# Dataset info looked up once at import time and reused by the UI helpers in
# this module (presumably keyed by dataset name -- confirm in dataset_utils).
DATASET_NAME_TO_DICT = dataset_utils.get_dataset_info_dicts()
|
|
|
|
|
def get_load_prepare_list(dstats):
    """Pair each displayed measurement with its load-or-prepare method.

    Args:
        dstats: Object exposing a ``load_or_prepare_text_lengths`` attribute
            (a DatasetStatisticsCacheClass in this app).

    Returns:
        list: ``(measurement_name, load_or_prepare_callable)`` tuples; at the
        moment only the ``"text_lengths"`` measurement is listed.
    """
    # Insertion order of the mapping is the order the measurements are returned in.
    measurement_loaders = {
        "text_lengths": dstats.load_or_prepare_text_lengths,
    }
    return list(measurement_loaders.items())
|
|
|
|
|
def get_ui_widgets():
    """Instantiate the widgets that will be displayed in the UI.

    Returns:
        list: A one-element list holding a new ``widgets.TextLengths`` instance.
    """
    text_lengths_widget = widgets.TextLengths()
    return [text_lengths_widget]
|
|
|
|
|
def get_widgets():
    """Return the two pieces every measurement widget needs.

    A measurement widget requires:
      - a load-or-prepare function, and
      - a display widget.

    These are defined separately in get_load_prepare_list and get_ui_widgets;
    a widget can be added by modifying both functions and the rest of the app
    logic will work. get_load_prepare_list is handed back uncalled because it
    needs a DatasetStatisticsCacheClass, which is not created until dataset
    and config values are selected in the UI.

    Returns:
        tuple: ``(get_load_prepare_list, ui_widgets)`` -- the function object
        itself and the list produced by calling ``get_ui_widgets()``.
    """
    ui_widgets = get_ui_widgets()
    return get_load_prepare_list, ui_widgets
|
|
|
|
|
def get_title(dstats):
    """Build the Markdown heading describing the current selection.

    Args:
        dstats: Object with ``dset_name``, ``dset_config``, ``split_name`` and
            ``text_field`` attributes. ``text_field`` is passed to
            ``"-".join``, so it must be an iterable of strings.

    Returns:
        str: ``"### Showing: <name> - <config> - <split> - <fields>"``.
    """
    # Segments are evaluated in the same order they appear in the heading.
    segments = (
        f"{dstats.dset_name}",
        f"{dstats.dset_config}",
        f"{dstats.split_name}",
        "-".join(dstats.text_field),
    )
    heading = "### Showing: " + " - ".join(segments)
    logs.info("showing header")
    return heading
|
|
|
|
|
def display_initial_UI():
    """Create the dataset-selection controls.

    Returns:
        Whatever ``gr_utils.sidebar_selection`` returns for
        ``DATASET_NAME_TO_DICT``; callers in this file index it by keys such
        as ``"dset_name"`` and ``"text_field"``.
    """
    return gr_utils.sidebar_selection(DATASET_NAME_TO_DICT)
|
|
|
|
|
|
|
|
|
def show_column(dstats, display_list, show_perplexities, column_id=""):
    """Display the header and each measurement widget for a dataset.

    A failure inside one widget's display function is logged and does not
    prevent the remaining widgets from being displayed.

    Args:
        dstats (class): The dataset_statistics.py DatasetStatisticsCacheClass
        display_list (list): List of tuples for (widget_name, widget_display_function)
        show_perplexities (Bool): Whether perplexities should be loaded and displayed for this dataset
        column_id (str): Which column of the dataset the analysis is done on [DEPRECATED for v1]
    """
    gr_utils.expander_header(dstats, DATASET_NAME_TO_DICT)

    for entry in display_list:
        name, render = entry[0], entry[1]
        logs.info("showing %s." % name)
        try:
            render(dstats, column_id)
        except Exception as err:
            # Deliberately best-effort: record the problem and keep going.
            logs.warning("Jk jk jk. There was an issue with %s:" % name)
            logs.exception(err)

    # Perplexities are handled separately from the widgets in display_list.
    if show_perplexities:
        gr_utils.expander_text_perplexities(dstats, column_id)
    logs.info("Have finished displaying the widgets.")
|
|
|
|
|
def create_demo(live: bool, pull_cache_from_hub: bool):
    """Build the Gradio Blocks app for the Data Measurements Tool.

    Lays out a narrow selection column and a wide measurement column, then
    wires the page-load, dropdown-change and button-click events.

    Args:
        live: Forwarded unchanged to ``load_or_prepare_widgets``.
        pull_cache_from_hub: Forwarded unchanged to ``load_or_prepare_widgets``.

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks() as demo:
        state = gr.State()
        with gr.Row():
            with gr.Column(scale=1):
                dataset_args = display_initial_UI()
                get_load_prepare_list_fn, widget_list = get_widgets()
                # NOTE(review): this checkbox is rendered but is not an input of
                # any event below; update_ui always passes show_perplexities=False.
                gr.Checkbox(label="Show text perplexities")
            with gr.Column(scale=4):
                gr.Markdown("# Data Measurements Tool")
                title = gr.Markdown()
                for widget in widget_list:
                    widget.render()

        def update_ui(dataset: str, config: str, split: str, feature: str):
            """Load or compute the measurements for the selection and refresh the UI."""
            # The text-field dropdown delivers its value as a string holding a
            # Python literal; literal_eval only accepts literals, not arbitrary code.
            feature = ast.literal_eval(feature)
            label_field, label_names = gr_utils.get_label_names(dataset, config, DATASET_NAME_TO_DICT)
            dstats = dmt_cls(dset_name=dataset, dset_config=config, split_name=split, text_field=feature,
                             label_field=label_field, label_names=label_names, use_cache=True)
            load_prepare_list = get_load_prepare_list_fn(dstats)
            dstats = load_or_prepare_widgets(dstats, load_prepare_list, show_perplexities=False,
                                             live=live, pull_cache_from_hub=pull_cache_from_hub)
            # Keyed by component: Gradio updates only the outputs present in the dict.
            output = {title: get_title(dstats), state: dstats}
            for widget in widget_list:
                output.update(widget.update(dstats))
            return output

        def update_dataset(dataset: str):
            """Refresh the config, text-field and split dropdowns for a new dataset."""
            new_values = gr_utils.update_dataset(dataset, DATASET_NAME_TO_DICT)
            # Entries are indexed as [choices, selected value], in the order
            # config, text field, split.
            # NOTE(review): new_values[0][0] is assumed to hold the config
            # choices, mirroring entries 1 and 2 -- confirm in gr_utils.update_dataset.
            config = new_values[0][1]
            feature = new_values[1][1]
            split = new_values[2][1]
            return {
                # Previously omitted although dset_config is a declared output,
                # which left the config dropdown showing the old dataset's configs.
                dataset_args["dset_config"]: gr.Dropdown.update(choices=new_values[0][0], value=config),
                dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[1][0], value=feature),
                dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[2][0], value=split),
            }

        def update_config(dataset: str, config: str):
            """Refresh the text-field and split dropdowns for a new config."""
            new_values = gr_utils.update_config(dataset, config, DATASET_NAME_TO_DICT)
            feature = new_values[0][1]
            split = new_values[1][1]
            return {
                dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[0][0], value=feature),
                dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[1][0], value=split),
            }

        # Every component any widget may update, flattened into one list.
        measurements = [comp for widget in widget_list for comp in widget.output_components]
        selection_inputs = [dataset_args["dset_name"], dataset_args["dset_config"],
                            dataset_args["split_name"], dataset_args["text_field"]]
        ui_outputs = [title, state] + measurements

        # Compute the measurements for the initial selection on page load.
        demo.load(update_ui, inputs=selection_inputs, outputs=ui_outputs)

        for widget in widget_list:
            widget.add_events(state)

        dataset_args["dset_name"].change(
            update_dataset,
            inputs=[dataset_args["dset_name"]],
            outputs=[dataset_args["dset_config"], dataset_args["split_name"],
                     dataset_args["text_field"]] + ui_outputs,
        )

        dataset_args["dset_config"].change(
            update_config,
            inputs=[dataset_args["dset_name"], dataset_args["dset_config"]],
            outputs=[dataset_args["split_name"], dataset_args["text_field"]] + ui_outputs,
        )

        dataset_args["calculate_btn"].click(update_ui, inputs=selection_inputs, outputs=ui_outputs)
    return demo
|
|
|
|
|
def main():
    """Parse the command-line flags, build the demo and launch it.

    Flags:
        --live: Boolean flag, False unless given; passed to ``create_demo``.
        --pull_cache_from_hub: Boolean flag, False unless given; passed to
            ``create_demo``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the help text says "not running live", yet passing --live
    # stores True -- confirm the intended wording.
    parser.add_argument(
        "--live", default=False, required=False, action="store_true", help="Flag to specify that this is not running live.")
    parser.add_argument(
        "--pull_cache_from_hub", default=False, required=False, action="store_true", help="Flag to specify whether to look in the hub for measurements caches. If you are using this option, you must have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache> and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.")
    arguments = parser.parse_args()

    # create_demo builds the selection widgets itself inside its gr.Blocks
    # context, so no UI components are created here beforehand.
    demo = create_demo(arguments.live, arguments.pull_cache_from_hub)
    demo.launch()


if __name__ == "__main__":
    main()
|
|