# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data import Dataset

from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import logging
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures


logger = logging.get_logger(__name__)
@dataclass
class GlueDataTrainingArguments:
    """
    Arguments pertaining to the data we are going to input to our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments that can be specified on the
    command line.
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        self.task_name = self.task_name.lower()
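
# A minimal sketch of how `HfArgumentParser` consumes this dataclass from the command
# line, as the docstring above describes (illustrative only; the script name and the
# flag values are assumptions, not part of this module):
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser(GlueDataTrainingArguments)
#     (data_args,) = parser.parse_args_into_dataclasses()
#     # e.g. python run.py --task_name mrpc --data_dir ./glue_data/MRPC --max_seq_length 128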
class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"
class GlueDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach soon.
    """

    args: GlueDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    def __init__(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizerBase,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        cache_dir: Optional[str] = None,
    ):
        warnings.warn(
            "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
            "library. You can have a look at this example script for pointers: "
            "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
            FutureWarning,
        )
        self.args = args
        self.processor = glue_processors[args.task_name]()
        self.output_mode = glue_output_modes[args.task_name]
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
        )
        label_list = self.processor.get_labels()
        if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
            "RobertaTokenizer",
            "RobertaTokenizerFast",
            "XLMRobertaTokenizer",
            "BartTokenizer",
            "BartTokenizerFast",
        ):
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        self.label_list = label_list

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    f"Loading features from cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                )
            else:
                logger.info(f"Creating features from dataset file at {args.data_dir}")

                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                elif mode == Split.test:
                    examples = self.processor.get_test_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)
                if limit_length is not None:
                    examples = examples[:limit_length]
                self.features = glue_convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )
                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                )
    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

    def get_labels(self):
        return self.label_list
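
# A minimal usage sketch (illustrative only): it assumes the GLUE MRPC data has
# already been downloaded to ./glue_data/MRPC, and "bert-base-uncased" is used
# purely as an example checkpoint; neither is prescribed by this module.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
#     train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode=Split.train)
#     print(len(train_dataset), train_dataset[0])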