""" Geneformer precollator and pretrainer. Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data. """ import collections import math import pickle import warnings from enum import Enum from typing import Dict, Iterator, List, Optional, Union import numpy as np import torch from datasets import Dataset from packaging import version from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler from transformers import ( BatchEncoding, DataCollatorForLanguageModeling, SpecialTokensMixin, Trainer, ) from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled from transformers.trainer_pt_utils import ( DistributedLengthGroupedSampler, DistributedSamplerWithLoop, LengthGroupedSampler, ) from transformers.training_args import ParallelMode from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj from transformers.utils.generic import _is_tensorflow, _is_torch from .tokenizer import TOKEN_DICTIONARY_FILE logger = logging.get_logger(__name__) EncodedInput = List[int] VERY_LARGE_INTEGER = int( 1e30 ) # This is used to set the max input length for a model with infinite size input LARGE_INTEGER = int( 1e20 ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER if is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist else: import torch.distributed as dist _is_torch_generator_available = False if version.parse(torch.__version__) >= version.parse("1.6"): _is_torch_generator_available = True with open(TOKEN_DICTIONARY_FILE, "rb") as f: token_dictionary = pickle.load(f) class ExplicitEnum(Enum): """ Enum with more explicit error message for missing values. """ @classmethod def _missing_(cls, value): raise ValueError( "%r is not a valid %s, please select one of %s" % (value, cls.__name__, str(list(cls._value2member_map_.keys()))) ) class TruncationStrategy(ExplicitEnum): """ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ ONLY_FIRST = "only_first" ONLY_SECOND = "only_second" LONGEST_FIRST = "longest_first" DO_NOT_TRUNCATE = "do_not_truncate" class PaddingStrategy(ExplicitEnum): """ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ LONGEST = "longest" MAX_LENGTH = "max_length" DO_NOT_PAD = "do_not_pad" class TensorType(ExplicitEnum): """ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ PYTORCH = "pt" TENSORFLOW = "tf" NUMPY = "np" JAX = "jax" class GeneformerPreCollator(SpecialTokensMixin): def __init__(self, *args, **kwargs) -> None: self.token_dictionary = kwargs.get("token_dictionary") self.mask_token = "" self.mask_token_id = self.token_dictionary.get("") self.pad_token = "" self.pad_token_id = self.token_dictionary.get("") self.padding_side = "right" self.all_special_ids = [ self.token_dictionary.get(""), self.token_dictionary.get(""), ] self.model_input_names = ["input_ids"] super().__init__(*args, **kwargs) def _get_padding_truncation_strategies( self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs, ): """ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy and pad_to_max_length) and behaviors. """ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) # Backward compatibility for previous behavior, maybe we should deprecate it: # If you only set max_length, it activates truncation for max_length if max_length is not None and padding is False and truncation is False: if verbose: if not self.deprecation_warnings.get( "Truncation-not-explicitly-activated", False ): logger.warning( "Truncation was not explicitly activated but `max_length` is provided a specific value, " "please use `truncation=True` to explicitly truncate examples to max length. " "Defaulting to 'longest_first' truncation strategy. " "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " "more precisely by providing a specific strategy to `truncation`." ) self.deprecation_warnings["Truncation-not-explicitly-activated"] = True truncation = "longest_first" # Get padding strategy if padding is False and old_pad_to_max_length: if verbose: warnings.warn( "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " "maximal input size of the model (e.g. 512 for Bert).", FutureWarning, ) if max_length is None: padding_strategy = PaddingStrategy.LONGEST else: padding_strategy = PaddingStrategy.MAX_LENGTH elif padding is not False: if padding is True: padding_strategy = ( PaddingStrategy.LONGEST ) # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) elif isinstance(padding, PaddingStrategy): padding_strategy = padding else: padding_strategy = PaddingStrategy.DO_NOT_PAD # Get truncation strategy if truncation is False and old_truncation_strategy != "do_not_truncate": if verbose: warnings.warn( "The `truncation_strategy` argument is deprecated and will be removed in a future version, " "use `truncation=True` to truncate examples to a max length. You can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " "maximal input size of the model (e.g. 512 for Bert). " " If you have pairs of inputs, you can give a specific truncation strategy selected among " "`truncation='only_first'` (will only truncate the first sentence in the pairs) " "`truncation='only_second'` (will only truncate the second sentence in the pairs) " "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).", FutureWarning, ) truncation_strategy = TruncationStrategy(old_truncation_strategy) elif truncation is not False: if truncation is True: truncation_strategy = ( TruncationStrategy.LONGEST_FIRST ) # Default to truncate the longest sequences in pairs of inputs elif not isinstance(truncation, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation) elif isinstance(truncation, TruncationStrategy): truncation_strategy = truncation else: truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE # Set max length if needed if max_length is None: if padding_strategy == PaddingStrategy.MAX_LENGTH: if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get( "Asking-to-pad-to-max_length", False ): logger.warning( "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. " "Default to no padding." ) self.deprecation_warnings["Asking-to-pad-to-max_length"] = True padding_strategy = PaddingStrategy.DO_NOT_PAD else: max_length = self.model_max_length if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get( "Asking-to-truncate-to-max_length", False ): logger.warning( "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. " "Default to no truncation." ) self.deprecation_warnings[ "Asking-to-truncate-to-max_length" ] = True truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE else: max_length = self.model_max_length # Test if we have a padding token if padding_strategy != PaddingStrategy.DO_NOT_PAD and ( not self.pad_token or self.pad_token_id < 0 ): raise ValueError( "Asking to pad but the tokenizer does not have a padding token. " "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." ) # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided if ( truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and padding_strategy != PaddingStrategy.DO_NOT_PAD and pad_to_multiple_of is not None and max_length is not None and (max_length % pad_to_multiple_of != 0) ): raise ValueError( f"Truncation and padding are both activated but " f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." ) return padding_strategy, truncation_strategy, max_length, kwargs def pad( self, encoded_inputs: Union[ BatchEncoding, List[BatchEncoding], Dict[str, EncodedInput], Dict[str, List[EncodedInput]], List[Dict[str, EncodedInput]], ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = True, return_tensors: Optional[Union[str, TensorType]] = None, verbose: bool = True, ) -> BatchEncoding: """ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``) .. note:: If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless you provide a different tensor type with ``return_tensors``. In the case of PyTorch tensors, you will lose the specific device of your tensors however. Args: encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`): Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str, List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as well as in a PyTorch Dataloader collate function. Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). max_length (:obj:`int`, `optional`): Maximum length of the returned list and optionally padding length (see above). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). return_attention_mask (:obj:`bool`, `optional`): Whether to return the attention mask. If left to the default, will return the attention mask according to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are attention masks? <../glossary.html#attention-mask>`__ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): If set, will return tensors instead of list of python integers. Acceptable values are: * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to print more information and warnings. """ # If we have a list of dicts, let's convert it in a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader if isinstance(encoded_inputs, (list, tuple)) and isinstance( encoded_inputs[0], (dict, BatchEncoding) ): encoded_inputs = { key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys() } # The model's main input name, usually `input_ids`, has be passed for padding if self.model_input_names[0] not in encoded_inputs: raise ValueError( "You should supply an encoding or a list of encodings to this method" f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" ) required_input = encoded_inputs[self.model_input_names[0]] if not required_input: if return_attention_mask: encoded_inputs["attention_mask"] = [] return encoded_inputs # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch first_element = required_input[0] if isinstance(first_element, (list, tuple)): # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. index = 0 while len(required_input[index]) == 0: index += 1 if index < len(required_input): first_element = required_input[index][0] # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): if is_tf_available() and _is_tensorflow(first_element): return_tensors = "tf" if return_tensors is None else return_tensors elif is_torch_available() and _is_torch(first_element): return_tensors = "pt" if return_tensors is None else return_tensors if isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " f"Should be one of a python, numpy, pytorch or tensorflow object." ) for key, value in encoded_inputs.items(): encoded_inputs[key] = to_py_obj(value) # Convert padding_strategy in PaddingStrategy padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( padding=padding, max_length=max_length, verbose=verbose ) required_input = encoded_inputs[self.model_input_names[0]] if required_input and not isinstance(required_input[0], (list, tuple)): encoded_inputs = self._pad( encoded_inputs, max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) return BatchEncoding(encoded_inputs, tensor_type=return_tensors) batch_size = len(required_input) assert all( len(v) == batch_size for v in encoded_inputs.values() ), "Some items in the output dictionary have a different batch size than others." if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) padding_strategy = PaddingStrategy.MAX_LENGTH batch_outputs = {} for i in range(batch_size): inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) outputs = self._pad( inputs, max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) for key, value in outputs.items(): if key not in batch_outputs: batch_outputs[key] = [] batch_outputs[key].append(value) return BatchEncoding(batch_outputs, tensor_type=return_tensors) def _pad( self, encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch) Args: encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). max_length: maximum length of the returned list and optionally padding length (see below). Will truncate by taking into account the special tokens. padding_strategy: PaddingStrategy to use for padding. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad The tokenizer padding sides are defined in self.padding_side: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ # Load from model defaults if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names required_input = encoded_inputs[self.model_input_names[0]] if padding_strategy == PaddingStrategy.LONGEST: max_length = len(required_input) if ( max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0) ): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of needs_to_be_padded = ( padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length ) if needs_to_be_padded: difference = max_length - len(required_input) if self.padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = [1] * len(required_input) + [ 0 ] * difference if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = ( encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference ) if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = ( encoded_inputs["special_tokens_mask"] + [1] * difference ) encoded_inputs[self.model_input_names[0]] = ( required_input + [self.pad_token_id] * difference ) elif self.padding_side == "left": if return_attention_mask: encoded_inputs["attention_mask"] = [0] * difference + [1] * len( required_input ) if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = [ self.pad_token_type_id ] * difference + encoded_inputs["token_type_ids"] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [ 1 ] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [ self.pad_token_id ] * difference + required_input else: raise ValueError("Invalid padding strategy:" + str(self.padding_side)) elif return_attention_mask and "attention_mask" not in encoded_inputs: encoded_inputs["attention_mask"] = [1] * len(required_input) return encoded_inputs def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. Args: token_ids_0 (:obj:`List[int]`): List of ids of the first sequence. token_ids_1 (:obj:`List[int]`, `optional`): List of ids of the second sequence. already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the token list is already formatted with special tokens for the model. Returns: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ assert already_has_special_tokens and token_ids_1 is None, ( "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " "Please use a slow (full python) tokenizer to activate this argument." "Or set `return_special_tokens_mask=True` when calling the encoding method " "to get the special tokens mask in any tokenizer. " ) all_special_ids = self.all_special_ids # cache the property special_tokens_mask = [ 1 if token in all_special_ids else 0 for token in token_ids_0 ] return special_tokens_mask def convert_tokens_to_ids( self, tokens: Union[str, List[str]] ) -> Union[int, List[int]]: """ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the vocabulary. Args: tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s). Returns: :obj:`int` or :obj:`List[int]`: The token id or list of token ids. """ if tokens is None: return None if isinstance(tokens, str): return self._convert_token_to_id_with_added_voc(tokens) ids = [] for token in tokens: ids.append(self._convert_token_to_id_with_added_voc(token)) return ids def _convert_token_to_id_with_added_voc(self, token): if token is None: return None return self.token_dictionary.get(token) def __len__(self): return len(self.token_dictionary) class GeneformerPretrainer(Trainer): def __init__(self, *args, **kwargs): data_collator = kwargs.get("data_collator") token_dictionary = kwargs.get("token_dictionary") if data_collator is None: precollator = GeneformerPreCollator(token_dictionary=token_dictionary) # # Data Collator Functions data_collator = DataCollatorForLanguageModeling( tokenizer=precollator, mlm=True, mlm_probability=0.15 ) kwargs["data_collator"] = data_collator super().__init__(*args, **kwargs) # load previously saved length vector for dataset to speed up LengthGroupedSampler # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))] if kwargs.get("example_lengths_file"): with open(kwargs.get("example_lengths_file"), "rb") as f: self.example_lengths = pickle.load(f) else: raise Exception( "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl" ) # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if not isinstance(self.train_dataset, collections.abc.Sized): return None generator = None if self.args.world_size <= 1 and _is_torch_generator_available: generator = torch.Generator() generator.manual_seed( int(torch.empty((), dtype=torch.int64).random_().item()) ) # Build the sampler. if self.args.group_by_length: if is_datasets_available() and isinstance(self.train_dataset, Dataset): lengths = self.example_lengths else: lengths = None print(f"Lengths: {len(lengths)}") model_input_name = ( self.tokenizer.model_input_names[0] if self.tokenizer is not None else None ) if self.args.world_size <= 1: return LengthGroupedSampler( self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name, generator=generator, ) else: return CustomDistributedLengthGroupedSampler( self.train_dataset, self.args.train_batch_size, num_replicas=self.args.world_size, rank=self.args.process_index, lengths=lengths, model_input_name=model_input_name, seed=self.args.seed, ) else: if self.args.world_size <= 1: if _is_torch_generator_available: return RandomSampler(self.train_dataset, generator=generator) return RandomSampler(self.train_dataset) elif ( self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] and not self.args.dataloader_drop_last ): # Use a loop for TPUs when drop_last is False to have all batches have the same size. return DistributedSamplerWithLoop( self.train_dataset, batch_size=self.args.per_device_train_batch_size, num_replicas=self.args.world_size, rank=self.args.process_index, seed=self.args.seed, ) else: return DistributedSampler( self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index, seed=self.args.seed, ) class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler): r""" Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while keeping a bit of randomness. """ # Copied and adapted from PyTorch DistributedSampler. def __init__( self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, drop_last: bool = False, lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None, ): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() self.dataset = dataset self.batch_size = batch_size self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.drop_last = drop_last # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. if self.drop_last and len(self.dataset) % self.num_replicas != 0: # Split to nearest available length that is evenly divisible. # This is to ensure each rank receives the same amount of data when # using this Sampler. self.num_samples = math.ceil( (len(self.dataset) - self.num_replicas) / self.num_replicas ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.seed = seed self.model_input_name = ( model_input_name if model_input_name is not None else "input_ids" ) if lengths is None: print("Lengths is none - calculating lengths.") if ( not ( isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding) ) or self.model_input_name not in dataset[0] ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " f"'{self.model_input_name}' key." ) lengths = [len(feature[self.model_input_name]) for feature in dataset] self.lengths = lengths def __iter__(self) -> Iterator: # Deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g) if not self.drop_last: # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))] else: # remove tail of data to make it evenly divisible. indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) def get_length_grouped_indices( lengths, batch_size, mega_batch_mult=None, generator=None ): """ Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of similar lengths. To do this, the indices are: - randomly permuted - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size` - sorted by length in each mega-batch The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of maximum length placed first, so that an OOM happens sooner rather than later. """ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. if mega_batch_mult is None: # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000) # Just in case, for tiny datasets if mega_batch_mult == 0: mega_batch_mult = 1 # We need to use torch for the random part as a distributed sampler will set the random seed for torch. indices = torch.randperm(len(lengths), generator=generator) megabatch_size = mega_batch_mult * batch_size megabatches = [ indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size) ] megabatches = [ list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches ] # The rest is to get the biggest batch first. # Since each megabatch is sorted by descending length, the longest element is the first megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item() # Switch to put the longest element in first position megabatches[0][0], megabatches[max_idx][0] = ( megabatches[max_idx][0], megabatches[0][0], ) return [item for sublist in megabatches for item in sublist]