"""Gradio app for inspecting per-document filtering criteria on pilev2 subsets.

For each dataset subset and each precomputed criterion column, the app can
plot a histogram with a threshold line (and the share of documents that the
threshold would flag), or show one random document that the threshold flags.
"""
import os
import random
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset

dataset_names = [
    "AI4Code",
    "AMPS",
    "ASFPublicMail",
    "CPDataset",
    "DMMath",
    "Discourse",
    "Enwiki",
    "EuroParliamentProceedings",
    "FreeLaw_Options",
    "GithubDiff",
    "GithubIssues",
    "Gutenberg",
    "LeetCode",
    "PileOfLaw",
    "PubMed",
    "S2ORC",
    "StackExchange",
    "USENET",
    "USPTO",
    "UbuntuIRC",
    "arXiv",
]

# Criterion columns read from every subset; each is cached as a numpy array so
# plotting and sampling never have to iterate the dataset row by row.
CRITERIA_COLUMNS = (
    "check_word_number_criteria",
    "check_char_repetition_criteria",
    "check_flagged_words_criteria",
    "check_stop_word_ratio_criteria",
    "check_perplexity_criteria",
    "check_compression_ratio_criteria",
)

dataset_data = {}
for name in dataset_names:
    ds = load_dataset(
        "CarperAI/pilev2_smol_metadata",
        data_files=f"data/{name}/data.json",
        # When HF_TOKEN is unset, pass None so `datasets` can fall back to any
        # locally cached Hugging Face credentials instead of raising KeyError.
        # NOTE: newer `datasets` releases rename this argument to `token`.
        use_auth_token=os.environ.get("HF_TOKEN"),
        split="train",
    )
    dataset_data[name] = {"ds": ds}
    dataset_data[name].update({col: np.array(ds[col]) for col in CRITERIA_COLUMNS})


def _removal_mask(values, threshold, greater_than=True):
    """Return a boolean mask of the entries on the flagged side of *threshold*.

    Entries strictly greater than the threshold are flagged when
    ``greater_than`` is true; otherwise, entries strictly less than it are.
    """
    return values > threshold if greater_than else values < threshold


def plt_plot(criteria, dataset, threshold, greater_than=True):
    """Plot a histogram of one criterion column with the threshold marked.

    Args:
        criteria: Name of the criterion column in ``dataset_data[dataset]``.
        dataset: Subset name (a key of ``dataset_data``).
        threshold: Cut-off value, drawn as a red dashed vertical line.
        greater_than: If true, values above the threshold count as removed;
            otherwise, values below it do.

    Returns:
        The matplotlib figure. Its title reports the fraction of documents
        the threshold would remove.
    """
    # Release figures from earlier calls so pyplot does not accumulate them.
    plt.close("all")
    x = dataset_data[dataset][criteria]
    # Mean of a boolean mask is the flagged fraction; guard the empty case,
    # which would otherwise be 0/0 -> nan.
    perc = _removal_mask(x, threshold, greater_than).mean() if len(x) else 0.0

    fig, ax = plt.subplots()
    ax.hist(x, bins=50, color="black")
    ax.axvline(threshold, color="r", linestyle="dashed", linewidth=2)
    ax.set_title(f"{dataset} (removed {perc:.2%})")
    ax.set_xlabel("Value")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    return fig


def check_filtered(criteria, dataset, threshold, greater_than=True):
    """Return the text of one random document that *threshold* would remove.

    Uses the same comparison as :func:`plt_plot`, applied to the cached numpy
    column, so the sample is always consistent with the plotted percentage.

    Returns:
        The ``"text"`` field of a randomly chosen flagged row, or
        ``"No examples found"`` when no row is flagged.
    """
    data = dataset_data[dataset]
    flagged_rows = np.flatnonzero(_removal_mask(data[criteria], threshold, greater_than))
    if flagged_rows.size == 0:
        return "No examples found"
    row = int(random.choice(flagged_rows))
    return data["ds"][row]["text"]


# One entry per UI tab: (tab label, criterion column, slider maximum,
# greater_than). Only the compression-ratio criterion flags values *below*
# the threshold.
CRITERIA_TABS = [
    ("Character Repetition Criteria", "check_char_repetition_criteria", 1, True),
    ("Number of Words Criteria", "check_word_number_criteria", 50_000, True),
    ("Stop Word Ratio Criteria", "check_stop_word_ratio_criteria", 1, True),
    ("Flagged Word Criteria", "check_flagged_words_criteria", 1, True),
    ("Perplexity Criteria", "check_perplexity_criteria", 50_000, True),
    ("Compression Ratio Criteria", "check_compression_ratio_criteria", 1, False),
]


def _build_criteria_tab(dataset_input, label, criteria, maximum, greater_than):
    """Create one tab with a plot, a threshold slider, and the two buttons."""
    with gr.Tab(label):
        plot = gr.Plot()
        threshold = gr.Slider(minimum=0, maximum=maximum, label="Threshold")
        calculate = gr.Button("Calculate")
        check = gr.Button("Check Filtered Data")
        filtered_data = gr.Textbox(lines=5, label="Filtered Data")

        # partial binds the column name and direction now, avoiding the
        # late-binding pitfall of closures created inside a loop.
        plot_fn = partial(plt_plot, criteria, greater_than=greater_than)
        calculate.click(plot_fn, [dataset_input, threshold], plot)
        check_fn = partial(check_filtered, criteria, greater_than=greater_than)
        check.click(check_fn, [dataset_input, threshold], filtered_data)


with gr.Blocks() as demo:
    dataset = gr.Radio(dataset_names, label="Dataset", value="arXiv")
    for tab_label, tab_criteria, tab_maximum, tab_greater_than in CRITERIA_TABS:
        _build_criteria_tab(dataset, tab_label, tab_criteria, tab_maximum, tab_greater_than)


if __name__ == "__main__":
    demo.launch()