# IDEFICS_Data_Measurement_Tool / run_data_measurements.py
# Uploaded by Ezi in commit 46df0b6 ("Upload 312 files"); 14.3 kB.
import argparse
import json
from dotenv import load_dotenv
import plotly
import shutil
import smtplib
import ssl
import sys
import textwrap
from data_measurements import dataset_statistics
from data_measurements.zipf import zipf
from huggingface_hub import create_repo, Repository, hf_api
from os import getenv
from os.path import exists, join as pjoin
from pathlib import Path
import utils
from utils import dataset_utils
# Module-level logger used by every function in this script; it is created
# by the project's utils.prepare_logging helper, keyed on this file's path.
logs = utils.prepare_logging(__file__)
def load_or_prepare_widgets(ds_args, show_embeddings=False,
                            show_perplexities=False, use_cache=False):
    """
    Prepare (or load from cache) every statistic displayed by the app's widgets.

    Args:
        ds_args: keyword arguments forwarded to
            ``dataset_statistics.DatasetStatisticsCacheClass``.
        show_embeddings: when truthy, also prepare the text embeddings.
        show_perplexities: when truthy, also prepare the text perplexities.
        use_cache: forwarded to ``DatasetStatisticsCacheClass``.

    Returns:
        None; the statistics object is not returned to the caller.
    """
    stats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args, use_cache=use_cache)
    # Loader methods, in the order the widgets are prepared.
    loader_names = [
        "load_or_prepare_dset_peek",      # header widget
        "load_or_prepare_general_stats",  # general stats widget
        "load_or_prepare_labels",         # labels widget
        "load_or_prepare_text_lengths",   # text lengths widget
    ]
    if show_embeddings:
        loader_names.append("load_or_prepare_embeddings")  # embeddings widget
    if show_perplexities:
        loader_names.append("load_or_prepare_text_perplexities")  # perplexities widget
    loader_names += [
        "load_or_prepare_text_duplicates",  # text duplicates widget
        "load_or_prepare_npmi",             # nPMI widget
        "load_or_prepare_zipf",             # Zipf widget
    ]
    for loader_name in loader_names:
        getattr(stats, loader_name)()
def _log_result_files(fid_dict):
    """Log where a measurement wrote its results.

    Args:
        fid_dict: mapping of result name to file location. A value that is
            itself a dict is logged as a "name:" line followed by its own
            entries, each prefixed with a tab.
    """
    logs.info("If all went well, then results are in the following files:")
    for key, value in fid_dict.items():
        if isinstance(value, dict):
            logs.info("%s:" % key)
            for sub_key, sub_value in value.items():
                logs.info("\t%s: %s" % (sub_key, sub_value))
        else:
            logs.info("%s: %s" % (key, value))


def load_or_prepare(dataset_args, calculation=False, use_cache=False):
    """Compute (or load from cache) the requested dataset measurements.

    The dataset is always tokenized and its vocabulary prepared first.

    Args:
        dataset_args: keyword arguments forwarded to
            ``dataset_statistics.DatasetStatisticsCacheClass``.
        calculation: name of a single measurement to run: "general",
            "duplicates", "lengths", "labels", "npmi", "zipf", "embeddings"
            or "perplexities". Any falsy value runs every measurement except
            "embeddings" and "perplexities", which only run when requested
            by name. An unrecognised name is logged as a warning.
        use_cache: forwarded to ``DatasetStatisticsCacheClass``.
    """
    # TODO: Catch error exceptions for each measurement, so that an error
    # for one measurement doesn't break the calculation of all of them.
    known_calculations = {"general", "duplicates", "lengths", "labels",
                          "npmi", "zipf", "embeddings", "perplexities"}
    if calculation and calculation not in known_calculations:
        # Previously this silently computed nothing beyond tokenization/vocab.
        logs.warning(
            "Unknown calculation %r; only tokenization and vocab will be "
            "prepared. Options are: %s"
            % (calculation, ", ".join(sorted(known_calculations))))
    do_all = not calculation
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args,
                                                            use_cache=use_cache)
    logs.info("Tokenizing dataset.")
    dstats.load_or_prepare_tokenized_df()
    logs.info("Calculating vocab.")
    dstats.load_or_prepare_vocab()

    if do_all or calculation == "general":
        logs.info("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        logs.info("Done!")
        logs.info(
            "Basic text statistics now available at %s." % dstats.general_stats_json_fid)

    if do_all or calculation == "duplicates":
        logs.info("\n* Calculating text duplicates.")
        dstats.load_or_prepare_text_duplicates()
        _log_result_files(dstats.duplicates_files)

    if do_all or calculation == "lengths":
        logs.info("\n* Calculating text lengths.")
        dstats.load_or_prepare_text_lengths()
        _log_result_files(dstats.length_obj.get_filenames())

    if do_all or calculation == "labels":
        logs.info("\n* Calculating label statistics.")
        if dstats.label_field not in dstats.dset.features:
            logs.warning("No label field found.")
            logs.info("No label statistics to calculate.")
        else:
            dstats.load_or_prepare_labels()
            _log_result_files(dstats.label_files)

    if do_all or calculation == "npmi":
        logs.info("\n* Preparing nPMI.")
        dstats.load_or_prepare_npmi()
        # nPMI results may contain nested {name: file} dicts.
        _log_result_files(dstats.npmi_files)

    if do_all or calculation == "zipf":
        logs.info("\n* Preparing Zipf.")
        dstats.load_or_prepare_zipf()
        logs.info("Done!")
        zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids(
            dstats.dataset_cache_dir)
        logs.info("Zipf results now available at %s." % zipf_json_fid)
        logs.info(
            "Figure saved to %s, with corresponding json at %s."
            % (zipf_fig_html_fid, zipf_fig_json_fid)
        )

    # Don't do this one until someone specifically asks for it -- takes awhile.
    if calculation == "embeddings":
        logs.info("\n* Preparing text embeddings.")
        dstats.load_or_prepare_embeddings()

    # Don't do this one until someone specifically asks for it -- takes awhile.
    if calculation == "perplexities":
        logs.info("\n* Preparing text perplexities.")
        dstats.load_or_prepare_text_perplexities()
def pass_args_to_DMT(dset_name, dset_config, split_name, text_field,
                     label_field, label_names, calculation, dataset_cache_dir,
                     prepare_gui=False, use_cache=True):
    """Bundle the dataset arguments and hand them to the right loader.

    With ``prepare_gui`` truthy, the widget statistics are prepared via
    ``load_or_prepare_widgets``. Otherwise ``load_or_prepare`` runs the
    requested ``calculation``. ``use_cache`` is forwarded either way, and a
    falsy value is logged.
    """
    if not use_cache:
        logs.info("Not using any cache; starting afresh")
    dataset_args = dict(
        dset_name=dset_name,
        dset_config=dset_config,
        split_name=split_name,
        text_field=text_field,
        label_field=label_field,
        label_names=label_names,
        dataset_cache_dir=dataset_cache_dir,
    )
    if not prepare_gui:
        load_or_prepare(dataset_args, calculation=calculation, use_cache=use_cache)
        return
    load_or_prepare_widgets(dataset_args, use_cache=use_cache)
def set_defaults(args):
    """Fill in defaults for dataset arguments left empty on the command line.

    Each of ``config``, ``split``, ``feature`` and ``label_field`` whose value
    is falsy is replaced by its default, and the assumption is logged. The
    namespace is modified in place and also returned.
    """
    fallbacks = (
        ("config", "default", "Config name not specified. Assuming it's 'default'."),
        ("split", "train", "Split name not specified. Assuming it's 'train'."),
        ("feature", "text", "Text column name not given. Assuming it's 'text'."),
        ("label_field", "label", "Label column name not given. Assuming it's 'label'."),
    )
    for attr_name, fallback, notice in fallbacks:
        if getattr(args, attr_name):
            continue
        setattr(args, attr_name, fallback)
        logs.info(notice)
    return args
# Account used to log in to the SMTP server and as the sender of notifications.
_SENDER_EMAIL = "data.measurements.tool@gmail.com"
# Values accepted by --calculation.
_CALCULATION_CHOICES = ("general", "duplicates", "lengths", "labels",
                        "embeddings", "perplexities", "npmi", "zipf")


def _build_arg_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """
            Example for hate speech18 dataset:
            python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"
            Example for IMDB dataset:
            python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
            """
        ),
    )
    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset to prepare"
    )
    parser.add_argument(
        "-c", "--config", required=False, default="", help="Dataset configuration to prepare"
    )
    parser.add_argument(
        "-s", "--split", required=False, default="", type=str,
        help="Dataset split to prepare"
    )
    parser.add_argument(
        "-f",
        "--feature",
        "-t",
        "--text-field",
        required=False,
        nargs="+",
        type=str,
        default="",
        help="Column to prepare (handled as text)",
    )
    parser.add_argument(
        "-w",
        "--calculation",
        # Reject typos up front instead of silently computing nothing.
        choices=_CALCULATION_CHOICES,
        help="""What to calculate (defaults to everything except embeddings and perplexities).\n
        Options are:\n
        - `general` (for duplicate counts, missing values, length statistics.)\n
        - `duplicates` for duplicate counts\n
        - `lengths` for text length distribution\n
        - `labels` for label distribution\n
        - `embeddings` (Warning: Slow.)\n
        - `perplexities` (Warning: Slow.)\n
        - `npmi` for word associations\n
        - `zipf` for zipfian statistics
        """,
    )
    parser.add_argument(
        "-l",
        "--label_field",
        type=str,
        required=False,
        default="",
        help="Field name for label column in dataset (Required if there is a label field that you want information about)",
    )
    parser.add_argument('-n', '--label_names', nargs='+', default=[])
    parser.add_argument(
        "--use_cache",
        default=False,
        required=False,
        action="store_true",
        help="Whether to use cached files (Optional)",
    )
    parser.add_argument("--out_dir", default="cache_dir",
                        help="Where to write out to.")
    parser.add_argument(
        "--overwrite_previous",
        default=False,
        required=False,
        action="store_true",
        help="Whether to overwrite a previous local cache for these same arguments (Optional)",
    )
    parser.add_argument(
        "--email",
        default=None,
        help="An email that receives a message about whether the computation was successful. If email is not None, then you must have EMAIL_PASSWORD=<your email password> for the sender email (data.measurements.tool@gmail.com) in a file named .env at the root of this repo.")
    parser.add_argument(
        "--push_cache_to_hub",
        default=False,
        required=False,
        action="store_true",
        help="Whether to push the cache to an organization on the hub. If you are using this option, you must have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache> and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.",
    )
    parser.add_argument("--prepare_GUI_data", default=False, required=False,
                        action="store_true",
                        help="Use this to process all of the stats used in the GUI.")
    # --keep_local is a no-op kept for backward compatibility: store_true with
    # default=True can never produce False. --no_keep_local provides the
    # only way to turn it off.
    parser.add_argument("--keep_local", default=True, required=False,
                        action="store_true",
                        help="Whether to save the data locally (the default).")
    parser.add_argument("--no_keep_local", dest="keep_local",
                        action="store_false",
                        help="Delete the local copy of the measurements after they have been pushed to the hub (only takes effect with --push_cache_to_hub).")
    return parser


def _send_email(recipient, subject, body):
    """Email ``body`` to ``recipient`` from the tool's Gmail account.

    A fresh SMTP-over-SSL connection is opened and closed for each message
    rather than holding one open across a possibly long computation, during
    which the server may drop it. Send failures are logged, not raised, so a
    notification problem cannot change the reported outcome of the run.
    """
    from email.message import EmailMessage

    message = EmailMessage()
    message["Subject"] = subject
    message["From"] = _SENDER_EMAIL
    message["To"] = recipient
    # EmailMessage handles non-ASCII text, unlike a raw str passed to sendmail.
    message.set_content(body)
    try:
        context = ssl.create_default_context()
        with smtplib.SMTP_SSL("smtp.gmail.com", 465, context=context) as server:
            server.login(_SENDER_EMAIL, getenv("EMAIL_PASSWORD"))
            server.send_message(message)
    except (smtplib.SMTPException, OSError):
        logs.exception("Could not send email notification to %s." % recipient)


def main():
    """Parse the command line, run the requested measurements and report.

    Steps:
      * parse arguments and fill defaults via ``set_defaults``;
      * load ``.env`` (when present) if --email or --push_cache_to_hub is set;
      * refuse to reuse an existing local cache directory unless --use_cache
        or --overwrite_previous is given;
      * run the measurements via ``pass_args_to_DMT`` and optionally push the
        cache to the hub;
      * email the outcome when --email is given.

    Errors raised while measuring are logged (and emailed), and main returns
    without re-raising them.
    """
    parser = _build_arg_parser()
    args = set_defaults(parser.parse_args())
    logs.info("Proceeding with the following arguments:")
    logs.info(args)
    # run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
    # .env holds EMAIL_PASSWORD and, for hub pushes, HUB_CACHE_ORGANIZATION
    # and HF_TOKEN (see the argument help texts).
    if (args.email is not None or args.push_cache_to_hub) and Path(".env").is_file():
        load_dotenv(".env")
    if args.email is not None and not getenv("EMAIL_PASSWORD"):
        # Fail before the (possibly long) computation, not when emailing.
        parser.error("--email requires EMAIL_PASSWORD to be set, e.g. in a "
                     "file named .env at the root of this repo.")

    dataset_cache_name, local_dataset_cache_dir = dataset_utils.get_cache_dir_naming(
        args.out_dir, args.dataset, args.config, args.split, args.feature)
    if not args.use_cache and exists(local_dataset_cache_dir):
        if args.overwrite_previous:
            shutil.rmtree(local_dataset_cache_dir)
        else:
            raise OSError("Cached results for this dataset already exist at %s. "
                          "Delete it or use the --overwrite_previous argument." % local_dataset_cache_dir)
    # Initialize the local cache directory
    dataset_utils.make_path(local_dataset_cache_dir)
    # Initialize the repository
    # TODO: print out local or hub cache directory location.
    repo = None
    if args.push_cache_to_hub:
        repo = dataset_utils.initialize_cache_hub_repo(local_dataset_cache_dir, dataset_cache_name)

    # Run the measurements. Only the computation and the push are guarded;
    # the success notification is sent outside the try so an email failure
    # is never reported as a measurement failure.
    try:
        pass_args_to_DMT(
            dset_name=args.dataset,
            dset_config=args.config,
            split_name=args.split,
            text_field=args.feature,
            label_field=args.label_field,
            label_names=args.label_names,
            calculation=args.calculation,
            dataset_cache_dir=local_dataset_cache_dir,
            prepare_gui=args.prepare_GUI_data,
            use_cache=args.use_cache,
        )
        if repo is not None:
            repo.push_to_hub(commit_message="Added dataset cache.")
    except Exception as e:
        logs.exception(e)
        error_message = f"An error occurred in computing data measurements " \
                        f"for dataset with arguments: {args}. " \
                        f"Feel free to make an issue here: " \
                        f"https://github.com/huggingface/data-measurements-tool/issues"
        logs.warning("Data measurements not computed. ☹️")
        logs.warning(error_message)
        if args.email is not None:
            _send_email(args.email, "Data Measurements not Computed", error_message)
        return

    computed_message = f"Data measurements have been computed for dataset" \
                       f" with these arguments: {args}."
    logs.info(computed_message)
    if args.email is not None:
        computed_message += "\nYou can return to the data measurements tool " \
                            "to view them."
        _send_email(args.email, "Data Measurements Computed!", computed_message)

    if args.keep_local:
        logs.info("Measurements made available locally at %s" % local_dataset_cache_dir)
    elif repo is None:
        # Deleting without a hub copy would throw away the results entirely.
        logs.warning("--no_keep_local was given but the cache was not pushed "
                     "to the hub; keeping measurements locally at %s" % local_dataset_cache_dir)
    else:
        # Remove the dataset from local storage - we only want it stored on the hub.
        logs.warning("Deleting measurements data locally at %s" % local_dataset_cache_dir)
        shutil.rmtree(local_dataset_cache_dir)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()