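"""SuperGLUE dataset wrapper: loads a task with `datasets`, applies task-specific
prompt templates, tokenizes the inputs, and exposes train/validation/test splits,
a data collator, and metric computation for use with a HuggingFace Trainer."""
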
import hashlib
import logging
import os.path
from collections import defaultdict

import numpy as np
import torch
from datasets.load import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
)
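
# Maps each SuperGLUE task to the (sentence1, sentence2) column names fed to the
# tokenizer; `None` for the second key means single-segment input, and (None, None)
# marks tasks (COPA, ReCoRD) that get fully custom preprocessing below.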
task_to_keys = {
    "boolq": ("question", "passage"),
    "cb": ("premise", "hypothesis"),
    "rte": ("premise", "hypothesis"),
    "wic": ("processed_sentence1", None),
    "wsc": ("span2_word_text", "span1_text"),
    "copa": (None, None),
    "record": (None, None),
    "multirc": ("paragraph", "question_answer"),
}

logger = logging.getLogger(__name__)


class SuperGlueDataset:
    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        super().__init__()
        raw_datasets = load_dataset("super_glue", data_args.dataset_name)
        self.tokenizer = tokenizer
        self.data_args = data_args

        self.multiple_choice = data_args.dataset_name in ["copa"]

        if data_args.dataset_name == "record":
            self.num_labels = 2
            self.label_list = ["0", "1"]
        elif not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            self.num_labels = 1

        # Preprocessing the raw_datasets
        self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]

        # Padding strategy
        if data_args.pad_to_max_length:
            self.padding = "max_length"
        else:
            # Pad later, dynamically at batch creation, to the max sequence length in each batch
            self.padding = False

        if not self.multiple_choice:
            self.label2id = {l: i for i, l in enumerate(self.label_list)}
            self.id2label = {id: label for label, id in self.label2id.items()}
            print(f"{self.label2id}")
            print(f"{self.id2label}")

        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

        if data_args.dataset_name == "record":
            # ReCoRD preprocessing is expensive, so cache the processed dataset,
            # keyed on the tokenizer (first 16 hex chars of the digest).
            digest = hashlib.md5(f"record_{tokenizer.name_or_path}".encode("utf-8")).hexdigest()[:16]
            path = raw_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"record-{digest}.arrow")
            if not os.path.exists(path):
                print(f"-> path not found!: {path}")
                raw_datasets = raw_datasets.map(
                    self.record_preprocess_function,
                    batched=True,
                    load_from_cache_file=not data_args.overwrite_cache,
                    remove_columns=raw_datasets["train"].column_names,
                    desc="Running tokenizer on dataset",
                )
                data = {"raw_datasets": raw_datasets}
                torch.save(data, path)
            raw_datasets = torch.load(path)["raw_datasets"]
        else:
            raw_datasets = raw_datasets.map(
                self.preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )

        if training_args.do_train:
            self.train_dataset = raw_datasets["train"]
            if data_args.max_train_samples is not None:
                self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))

        if training_args.do_eval:
            self.eval_dataset = raw_datasets["validation"]
            if data_args.max_eval_samples is not None:
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))

        if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
            self.predict_dataset = raw_datasets["test"]
            if data_args.max_predict_samples is not None:
                self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))

        self.metric = load_metric("super_glue", data_args.dataset_name)

        if data_args.pad_to_max_length:
            self.data_collator = default_data_collator
        elif training_args.fp16:
            self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        else:
            # Leave unset so the Trainer falls back to its default dynamic-padding collator.
            self.data_collator = None

        self.test_key = "accuracy" if data_args.dataset_name not in ["record", "multirc"] else "f1"
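
    # Task-specific input construction: `template_id` selects among hand-written
    # prompt templates (e.g. for WiC, template 0 builds a single RoBERTa-style
    # sequence while template 1 builds a BERT-style sentence pair); the resulting
    # text columns are then tokenized according to `task_to_keys`.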
    def preprocess_function(self, examples):
        # WSC
        if self.data_args.dataset_name == "wsc":
            examples["span2_word_text"] = []
            for text, span2_index, span2_word in zip(examples["text"], examples["span2_index"], examples["span2_text"]):
                if self.data_args.template_id == 0:
                    examples["span2_word_text"].append(span2_word + ": " + text)
                elif self.data_args.template_id == 1:
                    words_a = text.split()
                    words_a[span2_index] = "*" + words_a[span2_index] + "*"
                    examples["span2_word_text"].append(' '.join(words_a))

        # WiC
        if self.data_args.dataset_name == "wic":
            examples["processed_sentence1"] = []
            if self.data_args.template_id == 1:
                self.sentence2_key = "processed_sentence2"
                examples["processed_sentence2"] = []
            for sentence1, sentence2, word, start1, end1, start2, end2 in zip(examples["sentence1"],
                                                                              examples["sentence2"], examples["word"],
                                                                              examples["start1"], examples["end1"],
                                                                              examples["start2"], examples["end2"]):
                if self.data_args.template_id == 0:  # RoBERTa
                    examples["processed_sentence1"].append(
                        f"{sentence1} {sentence2} Does {word} have the same meaning in both sentences?")
                elif self.data_args.template_id == 1:  # BERT
                    examples["processed_sentence1"].append(word + ": " + sentence1)
                    examples["processed_sentence2"].append(word + ": " + sentence2)

        # MultiRC
        if self.data_args.dataset_name == "multirc":
            examples["question_answer"] = []
            for question, answer in zip(examples["question"], examples["answer"]):
                examples["question_answer"].append(f"{question} {answer}")

        # COPA
        if self.data_args.dataset_name == "copa":
            examples["text_a"] = []
            for premise, question in zip(examples["premise"], examples["question"]):
                joiner = "because" if question == "cause" else "so"
                examples["text_a"].append(f"{premise} {joiner}")

            result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            # Pair the encodings of the two choices so each example becomes a (2, seq_len) stack.
            result = {}
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key in result1 and key in result2:
                    result[key] = []
                    for value1, value2 in zip(result1[key], result2[key]):
                        result[key].append([value1, value2])
            return result

        args = (
            (examples[self.sentence1_key],) if self.sentence2_key is None else (
                examples[self.sentence1_key], examples[self.sentence2_key])
        )
        result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)

        return result
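
    # Illustrative note (assumption, not from the original module): for COPA the
    # paired encodings above give `input_ids` of shape (batch, 2, seq_len), the
    # layout consumed by multiple-choice heads, e.g.:
    #
    #     from transformers import AutoModelForMultipleChoice
    #     model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
    #     # forward() expects input_ids of shape (batch_size, num_choices, seq_len)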
    def compute_metrics(self, p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)

        if self.data_args.dataset_name == "record":
            return self.record_compute_metrics(p)

        if self.data_args.dataset_name == "multirc":
            from sklearn.metrics import f1_score
            # sklearn's signature is f1_score(y_true, y_pred)
            return {"f1": f1_score(p.label_ids, preds)}

        if self.data_args.dataset_name is not None:
            result = self.metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif self.is_regression:  # leftover from the GLUE script; dataset_name is always set here
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
    def record_compute_metrics(self, p: EvalPrediction):
        from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
        probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        examples = self.eval_dataset
        qid2pred = defaultdict(list)
        qid2ans = {}
        for prob, example in zip(probs, examples):
            qid = example['question_id']
            qid2pred[qid].append((prob[1], example['entity']))
            if qid not in qid2ans:
                qid2ans[qid] = example['answers']

        n_correct, n_total = 0, 0
        f1, em = 0, 0
        for qid in qid2pred:
            preds = sorted(qid2pred[qid], reverse=True)
            entity = preds[0][1]
            n_total += 1
            n_correct += (entity in qid2ans[qid])
            f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
            em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
        acc = n_correct / n_total  # accuracy is computed for reference but not reported
        f1 = f1 / n_total
        em = em / n_total
        return {'f1': f1, 'exact_match': em}
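
    # ReCoRD preprocessing: each (passage, query, candidate entity) triple becomes
    # one binary classification example, labeled 1 iff the entity appears in the
    # gold answers. This multiplies the number of rows, hence the caching in __init__.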
| def record_preprocess_function(self, examples, split="train"): | |
| results = { | |
| "index": list(), | |
| "question_id": list(), | |
| "input_ids": list(), | |
| "attention_mask": list(), | |
| #"token_type_ids": list(), | |
| "label": list(), | |
| "entity": list(), | |
| "answers": list() | |
| } | |
| for idx, passage in enumerate(examples["passage"]): | |
| query, entities, answers = examples["query"][idx], examples["entities"][idx], examples["answers"][idx] | |
| index = examples["idx"][idx] | |
| passage = passage.replace("@highlight\n", "- ").replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "") | |
| for ent_idx, ent in enumerate(entities): | |
| question = query.replace("@placeholder", ent).replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "") | |
| result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, | |
| truncation=True) | |
| label = 1 if ent in answers else 0 | |
| results["input_ids"].append(result["input_ids"]) | |
| results["attention_mask"].append(result["attention_mask"]) | |
| #if "token_type_ids" in result.keys(): results["token_type_ids"].append(result["token_type_ids"]) | |
| results["label"].append(label) | |
| results["index"].append(index) | |
| results["question_id"].append(index["query"]) | |
| results["entity"].append(ent) | |
| results["answers"].append(answers) | |
| return results |
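

# --- Usage sketch (illustrative; `model` and the argument dataclasses are
# assumptions, not part of this module) ---
#
#     from transformers import AutoTokenizer, Trainer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     dataset = SuperGlueDataset(tokenizer, data_args, training_args)
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=dataset.train_dataset if training_args.do_train else None,
#         eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
#         compute_metrics=dataset.compute_metrics,
#         data_collator=dataset.data_collator,
#     )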