import csv import functools from typing import Dict, List, Optional, Tuple import datasets import pkg_resources import seqio import t5 import tensorflow as tf from t5.data.glue_utils import get_glue_metric, get_super_glue_metric from t5.evaluation import metrics as mt import promptsource.templates from promptsource.seqio_tasks import utils GET_METRICS = { "BLEU": mt.bleu, "ROUGE": mt.rouge, "Span Squad": mt.span_squad, "Squad": mt.squad, "Trivia QA": mt.trivia_qa, "Accuracy": mt.accuracy, "Sequence Accuracy": mt.sequence_accuracy, "Pearson Correlation": mt.pearson_corrcoef, "Spearman Correlation": mt.spearman_corrcoef, "MultiRC": mt.multirc_f1_over_all_answers, "AUC": mt.auc, "COQA F1": mt.coqa_f1, "Edit Distance": mt.edit_distance, # "Mean Reciprocal Rank": mt.accuracy, # NOTE not in T5? "Other": mt.accuracy, # Missing support for mean_multiclass_f1 etc. which need a num_classes parameter } MAX_EXAMPLES_PER_DATASET = 500_000 def strip_whitespace(output_or_target, example=None, is_target=False): """Cached tasks from promptsource all have a leading space on the ground-truth targets.""" return output_or_target.strip() def maybe_get_class_id_postprocessor(template): if template.get_fixed_answer_choices_list(): def postprocess_fn(output_or_target, example=None, is_target=False): output_or_target = strip_whitespace(output_or_target) return t5.data.postprocessors.string_label_to_class_id( output_or_target, label_classes=template.get_fixed_answer_choices_list() ) return postprocess_fn else: return strip_whitespace def get_tf_dataset(split, shuffle_files, seed, dataset_name, subset_name, template, split_mapping): # HF datasets does not support file-level shuffling del shuffle_files, seed dataset = datasets.load_dataset(dataset_name, subset_name) dataset = dataset[split_mapping[split]] dataset = utils.apply_template(dataset, template) return utils.hf_dataset_to_tf_dataset(dataset) def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None): template = all_templates.get_dataset(dataset_name, subset_name)[template_name] task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name) if dataset_name == "glue": metrics = get_glue_metric(subset_name) elif dataset_name == "super_glue": if subset_name in ("wsc.fixed", "multirc"): # TODO: WSC and MultiRC need special pre/postprocesing metrics = [mt.accuracy] else: metrics = get_super_glue_metric(subset_name) else: # TODO what if metric is null? metrics = [GET_METRICS[m] for m in template.metadata.metrics] dataset_splits = utils.get_dataset_splits(dataset_name, subset_name) split_mapping = split_mapping or {k: k for k in dataset_splits.keys()} dataset_fn = functools.partial( get_tf_dataset, seed=None, dataset_name=dataset_name, subset_name=subset_name, template=template, split_mapping=split_mapping, ) data_source = seqio.FunctionDataSource( dataset_fn, splits=list(split_mapping.keys()), num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()}, ) output_features = { "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32), "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32), } preprocessors = [ seqio.preprocessors.tokenize, seqio.preprocessors.append_eos, seqio.CacheDatasetPlaceholder(required=False), ] # Add train and normal eval tasks seqio.TaskRegistry.add( task_name, data_source, preprocessors=preprocessors, output_features=output_features, metric_fns=metrics, postprocess_fn=maybe_get_class_id_postprocessor(template), ) # Add rank classification eval task if template.answer_choices: rank_classification_preprocessor = functools.partial( t5.data.preprocessors.rank_classification, inputs_fn=lambda ex: tf.fill((len(ex["answer_choices"]),), ex["inputs"]), targets_fn=lambda ex: ex["answer_choices"], is_correct_fn=lambda ex: tf.equal(ex["answer_choices"], tf.strings.strip(ex["targets"])), weight_fn=lambda ex: 1.0, ) fixed_choices = template.get_fixed_answer_choices_list() num_classes = len(fixed_choices) if fixed_choices else None seqio.TaskRegistry.add( task_name + "_score_eval", data_source, preprocessors=[rank_classification_preprocessor] + preprocessors, output_features=output_features, metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=num_classes)], postprocess_fn=t5.data.postprocessors.rank_classification, ) datatset_subset_tuple = Tuple[str, Optional[str]] d4_train: List[datatset_subset_tuple] = [] d4_eval: List[datatset_subset_tuple] = [] d3_train_gpt: List[datatset_subset_tuple] = [] d3_train_sglue: List[datatset_subset_tuple] = [] bias_fairness_eval: List[datatset_subset_tuple] = [] gsheet: Dict[datatset_subset_tuple, Dict] = {} experiment_path = pkg_resources.resource_filename(__name__, "experiment_D4.csv") with open(experiment_path) as exp_file: reader = csv.DictReader(exp_file) for row in reader: if row["skip"]: continue if row["subset"] == "": row["subset"] = None # to match promptsource.Template object dataset_subset = (row["HF_name"], row["subset"]) if row["do_train"] == "TRUE": d4_train.append(dataset_subset) if row["do_eval"] == "TRUE": d4_eval.append(dataset_subset) if row["D3_do_train"] == "TRUE" and "GPT" in row["seed_paper"]: d3_train_gpt.append(dataset_subset) if row["D3_do_train"] == "TRUE" and row["HF_name"] == "super_glue": d3_train_sglue.append(dataset_subset) if ( row["do_eval"] == "TRUE" and row["task_by_convention"] == "bias_and_fairness" and row["HF_name"] != "winogender" ): bias_fairness_eval.append(dataset_subset) gsheet[dataset_subset] = row all_datasets = d4_train + d4_eval + d3_train_gpt + d3_train_sglue + bias_fairness_eval all_templates = promptsource.templates.TemplateCollection() all_templates.remove("anli") # Need to special-case ANLI due to weird split conventions # 3 stages of training/ablation: D4 -> GPT -> SuperGLUE d4_train_mixture: List[str] = [] # strings are dataset_subset_template gpt_train_mixture: List[str] = [] sglue_train_mixture: List[str] = [] d4_eval_mixture: List[str] = [] bias_fairness_eval_mixture: List[str] = [] mixture_cap: Dict[str, int] = {} single_original_task: Dict[Tuple[str, str], str] = {} all_original_tasks: List[str] = [] for dataset_name, subset_name in all_templates.keys: if (dataset_name, subset_name) not in all_datasets: all_templates.remove(dataset_name, subset_name) continue dataset = all_templates.get_dataset(dataset_name, subset_name) num_templates = len(dataset.all_template_names) train_size = gsheet[(dataset_name, subset_name)]["train_size"] if train_size == "": train_size = 0 else: train_size = int(train_size) if train_size > MAX_EXAMPLES_PER_DATASET: cap = MAX_EXAMPLES_PER_DATASET // num_templates else: cap = train_size for template_name in dataset.all_template_names: add_task(dataset_name, subset_name, template_name) template = dataset[template_name] task_name = utils.get_task_name(dataset_name, subset_name, template_name) if (dataset_name, subset_name) not in single_original_task and template.metadata.original_task: single_original_task[(dataset_name, subset_name)] = task_name if template.metadata.original_task: all_original_tasks.append(task_name) if (dataset_name, subset_name) in d4_train: d4_train_mixture.append(task_name) mixture_cap[task_name] = cap if (dataset_name, subset_name) in d3_train_gpt: gpt_train_mixture.append(task_name) mixture_cap[task_name] = cap if (dataset_name, subset_name) in d3_train_sglue: sglue_train_mixture.append(task_name) mixture_cap[task_name] = cap if (dataset_name, subset_name) in d4_eval: if template.metadata.original_task: d4_eval_mixture.append(task_name) # TODO use template.metadata.answer_choices here for rank eval if (dataset_name, subset_name) in bias_fairness_eval: bias_fairness_eval_mixture.append(task_name) # Special case for ANLI, which has weirdly-named splits and rounds that should be subsets dataset_name, subset_name = ("anli", None) dataset = all_templates.get_dataset(dataset_name, subset_name) for anli_round in ("r1", "r2", "r3"): for template_name in all_templates.get_dataset(dataset_name, subset_name).all_template_names: task_name = utils.get_task_name(dataset_name, subset_name, template_name) + f"_{anli_round}" split_mapping = { "train": f"train_{anli_round}", "validation": f"dev_{anli_round}", "test": f"test_{anli_round}", } add_task(dataset_name, subset_name, template_name, task_name, split_mapping) template = dataset[template_name] if template.metadata.original_task: d4_eval_mixture.append(task_name) # TODO or add to ANLI special mixture # TODO use template.metadata.answer_choices here for rank eval TASK_BLACKLIST = [ # Tasks which often tokenize to > 1024 tokens currently "hotpot_qa_distractor_Generate_Explanations", "hotpot_qa_fullwiki_Generate_Explanations", "hotpot_qa_distractor_Generate_Answer_and_Explanations", "hotpot_qa_fullwiki_Generate_Answer_and_Explanations", "hotpot_qa_fullwiki_Generate_Answer", "hotpot_qa_distractor_Generate_Answer", "hotpot_qa_distractor_Generate_Title_2", "hotpot_qa_fullwiki_Generate_Title_2", "hotpot_qa_fullwiki_Generate_Title_1", "hotpot_qa_distractor_Generate_Title_1", "hotpot_qa_distractor_Generate_Question", "hotpot_qa_fullwiki_Generate_Question", "tab_fact_tab_fact_tab_fact_3", "tab_fact_tab_fact_tab_fact_2", "tab_fact_tab_fact_tab_fact_1", "tab_fact_tab_fact_tab_fact_7", "tab_fact_tab_fact_tab_fact_4", "tab_fact_tab_fact_tab_fact_5", "tab_fact_tab_fact_tab_fact_6", "wiki_hop_masked_Choose_Best_Object_Candidate", "wiki_hop_masked_Indirect_Question_about_Birthplace_Citizenship_Place_of_Death", "narrativeqa_Template_05", "ecthr_cases_alleged_violation_prediction_silver_rationales", # Tasks with broken cached files "gigaword_summarize_", ] # Tasks that failed caching (won't try to fix them for now) - remove when we are done D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST = [ "amazon_polarity_Is_this_product_review_positive_score_eval", "amazon_polarity_Is_this_review_negative_score_eval", "amazon_polarity_Is_this_review_score_eval", "amazon_polarity_User_recommend_this_product_score_eval", "amazon_polarity_convey_negative_or_positive_sentiment_score_eval", "amazon_polarity_flattering_or_not_score_eval", "amazon_polarity_negative_or_positive_tone_score_eval", "amazon_polarity_user_satisfied_score_eval", "amazon_polarity_would_you_buy_score_eval", "dbpedia_14_given_a_choice_of_categories__score_eval", "dbpedia_14_given_list_what_category_does_the_paragraph_belong_to_score_eval", "dbpedia_14_pick_one_category_for_the_following_text_score_eval", "wiki_hop_original_choose_best_object_affirmative_1_score_eval", "wiki_hop_original_choose_best_object_affirmative_2_score_eval", "wiki_hop_original_choose_best_object_affirmative_3_score_eval", "wiki_hop_original_choose_best_object_interrogative_1_score_eval", "wiki_hop_original_choose_best_object_interrogative_2_score_eval", ] seqio.MixtureRegistry.add( "d4_train", [task for task in d4_train_mixture if task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "gpt_train", [task for task in gpt_train_mixture if task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "sglue_train", [task for task in sglue_train_mixture if task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "d4_gpt_train", [task for task in d4_train_mixture + gpt_train_mixture if task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "d4_gpt_sglue_train", [task for task in d4_train_mixture + gpt_train_mixture + sglue_train_mixture if task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "d4_eval", [task for task in d4_eval_mixture if task not in TASK_BLACKLIST], default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) # eval mixture does not need to be capped seqio.MixtureRegistry.add( "d4_score_eval", [ task for task in seqio.TaskRegistry.names() if task.endswith("_score_eval") and task.split("_score_eval")[0] in d4_eval_mixture and task.split("_score_eval")[0] not in TASK_BLACKLIST ], default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) # Train tasks we don't care about evaluating on D4_TRAIN_SKIP_EVAL = [ "paws_labeled_final", "adversarial_qa_dbidaf", "adversarial_qa_dbert", "duorc_ParaphraseRC", "dream", "amazon_polarity", "app_reviews", "imdb", "wiki_bio", "gigaword", "multi_news", "samsum", "dbpedia_14", "trec", ] seqio.MixtureRegistry.add( "d4_train_eval", [ task for task in d4_train_mixture if task not in TASK_BLACKLIST and not any([skip in task for skip in D4_TRAIN_SKIP_EVAL]) and task in all_original_tasks ], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "d4_train_score_eval", [ task for task in seqio.TaskRegistry.names() if task.endswith("_score_eval") and task.split("_score_eval")[0] in d4_train_mixture and task.split("_score_eval")[0] not in TASK_BLACKLIST and task not in D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST and not any([skip in task for skip in D4_TRAIN_SKIP_EVAL]) and task.split("_score_eval")[0] in all_original_tasks ], default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) seqio.MixtureRegistry.add( "d4_train_one_og_prompt", [task for task in single_original_task.values() if task in d4_train_mixture and task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "d4_train_all_og_prompts", [task for task in all_original_tasks if task in d4_train_mixture and task not in TASK_BLACKLIST], default_rate=lambda t: mixture_cap[t.name], ) seqio.MixtureRegistry.add( "bias_fairness_eval", bias_fairness_eval_mixture, default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), ) seqio.MixtureRegistry.add( "bias_fairness_eval_score_eval", [ task for task in seqio.TaskRegistry.names() if task.endswith("_score_eval") and task.split("_score_eval")[0] in bias_fairness_eval_mixture ], default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), )