import argparse
import json
from dotenv import load_dotenv
import plotly
import shutil
import smtplib
import ssl
import sys
import textwrap
from data_measurements import dataset_statistics
from data_measurements.zipf import zipf
from huggingface_hub import create_repo, Repository, hf_api
from os import getenv
from os.path import exists, join as pjoin
from pathlib import Path
import utils
from utils import dataset_utils

logs = utils.prepare_logging(__file__)

# Account used both to log in to the SMTP server and as the sender address.
_SENDER_EMAIL = "data.measurements.tool@gmail.com"

# Values of --calculation that load_or_prepare() acts on.
_KNOWN_CALCULATIONS = (
    "general",
    "duplicates",
    "lengths",
    "labels",
    "npmi",
    "zipf",
    "embeddings",
    "perplexities",
)


def load_or_prepare_widgets(ds_args, show_embeddings=False,
                            show_perplexities=False, use_cache=False):
    """
    Loader specifically for the widgets used in the app.

    Calls the ``load_or_prepare_*`` method of a
    ``DatasetStatisticsCacheClass`` for each widget, in display order.

    Args:
        ds_args: Dict of keyword arguments forwarded to
            ``DatasetStatisticsCacheClass``.
        show_embeddings: If True, also prepare the embeddings widget.
        show_perplexities: If True, also prepare the perplexities widget.
        use_cache: Forwarded to ``DatasetStatisticsCacheClass``.

    Returns:
        None.
    """
    dstats = dataset_statistics.DatasetStatisticsCacheClass(
        **ds_args, use_cache=use_cache)
    # Header widget
    dstats.load_or_prepare_dset_peek()
    # General stats widget
    dstats.load_or_prepare_general_stats()
    # Labels widget
    dstats.load_or_prepare_labels()
    # Text lengths widget
    dstats.load_or_prepare_text_lengths()
    if show_embeddings:
        # Embeddings widget
        dstats.load_or_prepare_embeddings()
    if show_perplexities:
        # Text perplexities widget
        dstats.load_or_prepare_text_perplexities()
    # Text duplicates widget
    dstats.load_or_prepare_text_duplicates()
    # nPMI widget
    dstats.load_or_prepare_npmi()
    # Zipf widget
    dstats.load_or_prepare_zipf()


def _print_result_files(fid_dict, expand_nested=False):
    """Print one ``name: path`` line per entry of *fid_dict*.

    With ``expand_nested=True`` a dict value is printed as an indented
    sub-listing instead of as a raw dict.
    """
    print("If all went well, then results are in the following files:")
    for key, value in fid_dict.items():
        if expand_nested and isinstance(value, dict):
            print(key + ":")
            for key2, value2 in value.items():
                print("\t%s: %s" % (key2, value2))
        else:
            print("%s: %s" % (key, value))
    print()


def load_or_prepare(dataset_args, calculation=False, use_cache=False):
    """Run one measurement, or all of the default ones, for a dataset.

    The dataset is always tokenized and its vocabulary prepared first.

    Args:
        dataset_args: Dict of keyword arguments forwarded to
            ``DatasetStatisticsCacheClass``.
        calculation: Name of the single measurement to run. Any falsy value
            runs every measurement except ``embeddings`` and
            ``perplexities``, which only run when asked for by name.
        use_cache: Forwarded to ``DatasetStatisticsCacheClass``.

    Returns:
        None.
    """
    # TODO: Catch error exceptions for each measurement, so that an error
    # for one measurement doesn't break the calculation of all of them.
    do_all = not calculation
    if calculation and calculation not in _KNOWN_CALCULATIONS:
        # Previously an unknown name silently ran nothing past this point.
        logs.warning(
            "Unknown calculation '%s'; nothing beyond tokenization and "
            "vocab will be computed. Options are: %s."
            % (calculation, ", ".join(_KNOWN_CALCULATIONS)))

    dstats = dataset_statistics.DatasetStatisticsCacheClass(
        **dataset_args, use_cache=use_cache)
    logs.info("Tokenizing dataset.")
    dstats.load_or_prepare_tokenized_df()
    logs.info("Calculating vocab.")
    dstats.load_or_prepare_vocab()

    if do_all or calculation == "general":
        logs.info("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        logs.info("Done!")
        logs.info(
            "Basic text statistics now available at %s."
            % dstats.general_stats_json_fid)

    if do_all or calculation == "duplicates":
        logs.info("\n* Calculating text duplicates.")
        dstats.load_or_prepare_text_duplicates()
        duplicates_fid_dict = dstats.duplicates_files
        logs.info("If all went well, then results are in the following files:")
        for key, value in duplicates_fid_dict.items():
            logs.info("%s: %s" % (key, value))

    if do_all or calculation == "lengths":
        logs.info("\n* Calculating text lengths.")
        dstats.load_or_prepare_text_lengths()
        _print_result_files(dstats.length_obj.get_filenames())

    if do_all or calculation == "labels":
        logs.info("\n* Calculating label statistics.")
        if dstats.label_field not in dstats.dset.features:
            logs.warning("No label field found.")
            logs.info("No label statistics to calculate.")
        else:
            dstats.load_or_prepare_labels()
            _print_result_files(dstats.label_files)

    if do_all or calculation == "npmi":
        print("\n* Preparing nPMI.")
        dstats.load_or_prepare_npmi()
        _print_result_files(dstats.npmi_files, expand_nested=True)

    if do_all or calculation == "zipf":
        logs.info("\n* Preparing Zipf.")
        dstats.load_or_prepare_zipf()
        logs.info("Done!")
        zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids(
            dstats.dataset_cache_dir)
        logs.info("Zipf results now available at %s." % zipf_json_fid)
        logs.info(
            "Figure saved to %s, with corresponding json at %s."
            % (zipf_fig_html_fid, zipf_fig_json_fid)
        )

    # Don't do this one until someone specifically asks for it -- takes awhile.
    if calculation == "embeddings":
        logs.info("\n* Preparing text embeddings.")
        dstats.load_or_prepare_embeddings()

    # Don't do this one until someone specifically asks for it -- takes awhile.
    if calculation == "perplexities":
        logs.info("\n* Preparing text perplexities.")
        dstats.load_or_prepare_text_perplexities()


def pass_args_to_DMT(dset_name, dset_config, split_name, text_field,
                     label_field, label_names, calculation, dataset_cache_dir,
                     prepare_gui=False, use_cache=True):
    """Bundle the dataset arguments and dispatch to the right loader.

    Args:
        dset_name, dset_config, split_name, text_field, label_field,
        label_names, dataset_cache_dir: Packed into the keyword-argument
            dict handed to the loader.
        calculation: Passed to ``load_or_prepare``; ignored when
            ``prepare_gui`` is True.
        prepare_gui: If True, call ``load_or_prepare_widgets`` instead of
            ``load_or_prepare``.
        use_cache: Forwarded to whichever loader is called.

    Returns:
        None.
    """
    if not use_cache:
        logs.info("Not using any cache; starting afresh")
    dataset_args = {
        "dset_name": dset_name,
        "dset_config": dset_config,
        "split_name": split_name,
        "text_field": text_field,
        "label_field": label_field,
        "label_names": label_names,
        "dataset_cache_dir": dataset_cache_dir
    }
    if prepare_gui:
        load_or_prepare_widgets(dataset_args, use_cache=use_cache)
    else:
        load_or_prepare(dataset_args, calculation=calculation,
                        use_cache=use_cache)


def set_defaults(args):
    """Fill in config, split, feature and label_field when left empty.

    Args:
        args: Parsed argument namespace; modified in place.

    Returns:
        The same ``args`` object.
    """
    if not args.config:
        args.config = "default"
        logs.info("Config name not specified. Assuming it's 'default'.")
    if not args.split:
        args.split = "train"
        logs.info("Split name not specified. Assuming it's 'train'.")
    if not args.feature:
        # NOTE(review): --feature is parsed with nargs="+", so a value given
        # on the command line is a list while this default is a plain
        # string -- confirm downstream code handles both.
        args.feature = "text"
        logs.info("Text column name not given. Assuming it's 'text'.")
    if not args.label_field:
        args.label_field = "label"
        logs.info("Label column name not given. Assuming it's 'label'.")
    return args


def _build_parser():
    """Return the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """

        Example for hate speech18 dataset:
        python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"

        Example for IMDB dataset:
        python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
        """
        ),
    )

    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset to prepare"
    )
    parser.add_argument(
        "-c", "--config", required=False, default="",
        help="Dataset configuration to prepare"
    )
    parser.add_argument(
        "-s", "--split", required=False, default="", type=str,
        help="Dataset split to prepare"
    )
    parser.add_argument(
        "-f",
        "--feature",
        "-t",
        "--text-field",
        required=False,
        nargs="+",
        type=str,
        default="",
        help="Column to prepare (handled as text)",
    )
    parser.add_argument(
        "-w",
        "--calculation",
        help="""What to calculate (defaults to everything except embeddings and perplexities).\n
                                                    Options are:\n

                                                    - `general` (for duplicate counts, missing values, length statistics.)\n

                                                    - `duplicates` for duplicate counts\n

                                                    - `lengths` for text length distribution\n

                                                    - `labels` for label distribution\n

                                                    - `embeddings` (Warning: Slow.)\n

                                                    - `perplexities` (Warning: Slow.)\n

                                                    - `npmi` for word associations\n

                                                    - `zipf` for zipfian statistics
                                                    """,
    )
    parser.add_argument(
        "-l",
        "--label_field",
        type=str,
        required=False,
        default="",
        help="Field name for label column in dataset (Required if there is a label field that you want information about)",
    )
    parser.add_argument('-n', '--label_names', nargs='+', default=[])
    parser.add_argument(
        "--use_cache",
        default=False,
        required=False,
        action="store_true",
        help="Whether to use cached files (Optional)",
    )
    parser.add_argument("--out_dir", default="cache_dir",
                        help="Where to write out to.")
    parser.add_argument(
        "--overwrite_previous",
        default=False,
        required=False,
        action="store_true",
        help="Whether to overwrite a previous local cache for these same arguments (Optional)",
    )
    parser.add_argument(
        "--email",
        default=None,
        help="An email that recieves a message about whether the computation was successful. If email is not None, then you must have EMAIL_PASSWORD= for the sender email (data.measurements.tool@gmail.com) in a file named .env at the root of this repo.")
    parser.add_argument(
        "--push_cache_to_hub",
        default=False,
        required=False,
        action="store_true",
        help="Whether to push the cache to an organization on the hub. If you are using this option, you must have HUB_CACHE_ORGANIZATION= and HF_TOKEN= on separate lines in a file named .env at the root of this repo.",
    )
    parser.add_argument("--prepare_GUI_data", default=False, required=False,
                        action="store_true",
                        help="Use this to process all of the stats used in the GUI.")
    parser.add_argument("--keep_local", default=True, required=False,
                        action="store_true",
                        help="Whether to save the data locally.")
    # --keep_local is store_true with default=True, so on its own it can
    # never turn local storage off; this flag makes that branch reachable
    # without changing the meaning of existing command lines.
    parser.add_argument("--no_keep_local", dest="keep_local", required=False,
                        action="store_false",
                        help="Delete the local cache once the run finishes (intended for use with --push_cache_to_hub).")
    return parser


def _open_mail_server():
    """Connect and log in to the SMTP server used for notifications.

    Returns:
        A logged-in ``smtplib.SMTP_SSL`` instance; the caller closes it.

    Raises:
        ValueError: If EMAIL_PASSWORD is not available after loading .env.
    """
    if Path(".env").is_file():
        load_dotenv(".env")
    email_password = getenv("EMAIL_PASSWORD")
    if not email_password:
        # Fail before the (possibly long) computation with a clear message
        # rather than an opaque error from inside smtplib.
        raise ValueError(
            "--email was given but EMAIL_PASSWORD is not set. Put "
            "EMAIL_PASSWORD= for %s in a file named .env at the root of "
            "this repo." % _SENDER_EMAIL)
    context = ssl.create_default_context()
    port = 465
    server = smtplib.SMTP_SSL("smtp.gmail.com", port, context=context)
    try:
        server.login(_SENDER_EMAIL, email_password)
    except Exception:
        # Don't leave the socket open when authentication fails.
        server.close()
        raise
    return server


def _close_mail_server(server):
    """End the SMTP session, tolerating an already-dropped connection."""
    try:
        server.quit()
    except (smtplib.SMTPException, OSError):
        server.close()


def _send_email(server, recipient, subject, body):
    """Send a notification; log a warning instead of raising on failure."""
    try:
        server.sendmail(_SENDER_EMAIL, recipient,
                        "Subject: " + subject + "\n\n" + body)
    except (smtplib.SMTPException, OSError, UnicodeError) as err:
        # sendmail() encodes str messages as ASCII, hence UnicodeError.
        logs.warning("Could not send notification email to %s: %s"
                     % (recipient, err))


def _run_measurements(args, server):
    """Prepare the cache directory, run the measurements and report.

    Args:
        args: Parsed arguments with defaults applied.
        server: Logged-in SMTP connection, or None when no email was
            requested.

    Raises:
        OSError: If a local cache already exists and neither --use_cache
            nor --overwrite_previous was given.
    """
    dataset_cache_name, local_dataset_cache_dir = dataset_utils.get_cache_dir_naming(
        args.out_dir, args.dataset, args.config, args.split, args.feature)
    if not args.use_cache and exists(local_dataset_cache_dir):
        if not args.overwrite_previous:
            raise OSError("Cached results for this dataset already exist at %s. "
                          "Delete it or use the --overwrite_previous argument."
                          % local_dataset_cache_dir)
        shutil.rmtree(local_dataset_cache_dir)

    # Initialize the local cache directory
    dataset_utils.make_path(local_dataset_cache_dir)

    # Initialize the repository
    # TODO: print out local or hub cache directory location.
    if args.push_cache_to_hub:
        repo = dataset_utils.initialize_cache_hub_repo(local_dataset_cache_dir,
                                                       dataset_cache_name)
    # Run the measurements. Only the computation and the push sit inside the
    # try, so that a failed notification email is not reported as a failed
    # computation.
    try:
        pass_args_to_DMT(
            dset_name=args.dataset,
            dset_config=args.config,
            split_name=args.split,
            text_field=args.feature,
            label_field=args.label_field,
            label_names=args.label_names,
            calculation=args.calculation,
            dataset_cache_dir=local_dataset_cache_dir,
            prepare_gui=args.prepare_GUI_data,
            use_cache=args.use_cache,
        )
        if args.push_cache_to_hub:
            repo.push_to_hub(commit_message="Added dataset cache.")
    except Exception as e:
        logs.exception(e)
        error_message = f"An error occurred in computing data measurements " \
                        f"for dataset with arguments: {args}. " \
                        f"Feel free to make an issue here: " \
                        f"https://github.com/huggingface/data-measurements-tool/issues"
        if server is not None:
            _send_email(server, args.email,
                        "Data Measurements not Computed", error_message)
        logs.warning("Data measurements not computed. ☹️")
        logs.warning(error_message)
        return

    computed_message = f"Data measurements have been computed for dataset" \
                       f" with these arguments: {args}."
    logs.info(computed_message)
    if server is not None:
        _send_email(server, args.email, "Data Measurements Computed!",
                    computed_message
                    + "\nYou can return to the data measurements tool "
                      "to view them.")

    if not args.keep_local:
        # Remove the dataset from local storage - we only want it stored on the hub.
        logs.warning("Deleting measurements data locally at %s"
                     % local_dataset_cache_dir)
        shutil.rmtree(local_dataset_cache_dir)
    else:
        logs.info("Measurements made available locally at %s"
                  % local_dataset_cache_dir)


def main():
    """Command-line entry point: parse arguments and run the measurements.

    When --email is given, an SMTP session is opened before the run and
    always closed afterwards, whether or not the run succeeded.
    """
    orig_args = _build_parser().parse_args()
    args = set_defaults(orig_args)
    logs.info("Proceeding with the following arguments:")
    logs.info(args)
    # run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
    server = _open_mail_server() if args.email is not None else None
    try:
        _run_measurements(args, server)
    finally:
        if server is not None:
            _close_mail_server(server)


if __name__ == "__main__":
    main()