#!/usr/bin/env python3 import logging import sys from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union import librosa import torch from datasets import DatasetDict, load_dataset from packaging import version from torch import nn from transformers import ( HfArgumentParser, Trainer, TrainingArguments, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining, is_apex_available, trainer_utils, ) from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices if is_apex_available(): from apex import amp if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"): _is_native_amp_available = True from torch.cuda.amp import autocast logger = logging.getLogger(__name__) @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. """ model_name_or_path: str = field( metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, ) freeze_feature_extractor: Optional[bool] = field( default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} ) verbose_logging: Optional[bool] = field( default=False, metadata={"help": "Whether to log verbose messages or not."}, ) max_gumbel_temperature: Optional[float] = field( default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."} ) min_gumbel_temperature: Optional[float] = field( default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."} ) gumbel_temperature_decay: Optional[float] = field( default=0.999995, metadata={"help": "Decay of gumbel temperature during training."} ) def configure_logger(model_args: ModelArguments, training_args: TrainingArguments): logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", handlers=[logging.StreamHandler(sys.stdout)], ) logging_level = logging.WARNING if model_args.verbose_logging: logging_level = logging.DEBUG elif trainer_utils.is_main_process(training_args.local_rank): logging_level = logging.INFO logger.setLevel(logging_level) @dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command line. """ dataset_name: str = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} ) dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) train_split_name: Optional[str] = field( default="train", metadata={ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" }, ) validation_split_name: Optional[str] = field( default="validation", metadata={ "help": ( "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'" ) }, ) speech_file_column: Optional[str] = field( default="file", metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"}, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} ) validation_split_percentage: Optional[int] = field( default=1, metadata={ "help": "The percentage of the train set used as validation set in case there's no validation split" }, ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) max_duration_in_seconds: Optional[float] = field( default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"} ) @dataclass class DataCollatorForWav2Vec2Pretraining: """ Data collator that will dynamically pad the inputs received and prepare masked indices for self-supervised pretraining. Args: model (:class:`~transformers.Wav2Vec2ForPreTraining`): The Wav2Vec2 model used for pretraining. The data collator needs to have access to config and ``_get_feat_extract_output_lengths`` function for correct padding. feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): The processor used for proccessing the data. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). max_length (:obj:`int`, `optional`): Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). """ model: Wav2Vec2ForPreTraining feature_extractor: Wav2Vec2FeatureExtractor padding: Union[bool, str] = "longest" pad_to_multiple_of: Optional[int] = None max_length: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # reformat list to dict and set to pytorch format batch = self.feature_extractor.pad( features, max_length=self.max_length, padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt", ) mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) batch_size = batch["input_values"].shape[0] # make sure that no loss is computed on padded inputs if batch["attention_mask"] is not None: # compute real output lengths according to convolution formula output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to( torch.long ) attention_mask = torch.zeros( (batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device ) # these two operations makes sure that all values # before the output lengths indices are attended to attention_mask[ (torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1) ] = 1 attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() # sample randomly masked indices batch["mask_time_indices"] = _compute_mask_indices( (batch_size, mask_indices_seq_length), self.model.config.mask_time_prob, self.model.config.mask_time_length, attention_mask=attention_mask, min_masks=2, ) return batch class Wav2Vec2PreTrainer(Trainer): """ Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. Trainer can decay gumbel softmax temperature during training. """ def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs): super().__init__(*args, **kwargs) self.num_update_step = 0 self.max_gumbel_temp = max_gumbel_temp self.min_gumbel_temp = min_gumbel_temp self.gumbel_temp_decay = gumbel_temp_decay def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (:obj:`nn.Module`): The model to train. inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument :obj:`labels`. Check your model's documentation for all accepted arguments. Return: :obj:`torch.Tensor`: The tensor with training loss on this batch. """ model.train() inputs = self._prepare_inputs(inputs) if self.use_amp: with autocast(): loss = self.compute_loss(model, inputs) else: loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1 or self.deepspeed: if model.module.config.ctc_loss_reduction == "mean": loss = loss.mean() elif model.module.config.ctc_loss_reduction == "sum": loss = loss.sum() / (inputs["mask_time_indices"]).sum() else: raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']") if self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps if self.use_amp: self.scaler.scale(loss).backward() elif self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() elif self.deepspeed: self.deepspeed.backward(loss) else: loss.backward() self.num_update_step += 1 # make sure gumbel softmax temperature is decayed if self.args.n_gpu > 1 or self.deepspeed: model.module.set_gumbel_temperature( max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp) ) else: model.set_gumbel_temperature( max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp) ) return loss.detach() def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() configure_logger(model_args, training_args) # Downloading and loading a dataset from the hub. datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) if "validation" not in datasets.keys(): # make sure only "validation" and "train" keys remain" datasets = DatasetDict() datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, ) else: # make sure only "validation" and "train" keys remain" datasets = DatasetDict() datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split="validation", cache_dir=model_args.cache_dir, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}", cache_dir=model_args.cache_dir, ) # only normalized-inputs-training is supported feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True ) def prepare_dataset(batch): # check that all files have the correct sampling rate batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate) return batch # load audio files into numpy arrays vectorized_datasets = datasets.map( prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names ) # filter audio files that are too long vectorized_datasets = vectorized_datasets.filter( lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) ) def normalize(batch): return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate) # normalize and transform to `BatchFeatures` vectorized_datasets = vectorized_datasets.map( normalize, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, remove_columns=vectorized_datasets["train"].column_names, ) # pretraining is only supported for "newer" stable layer norm architecture # apply_spec_augment has to be True, mask_feature_prob has to be 0.0 config = Wav2Vec2Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, gradient_checkpointing=training_args.gradient_checkpointing, ) if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": raise ValueError( "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and" " ``config.feat_extract_norm='layer'" ) model = Wav2Vec2ForPreTraining(config) data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor) trainer = Wav2Vec2PreTrainer( model=model, data_collator=data_collator, args=training_args, train_dataset=vectorized_datasets["train"], eval_dataset=vectorized_datasets["validation"], tokenizer=feature_extractor, max_gumbel_temp=model_args.max_gumbel_temperature, min_gumbel_temp=model_args.min_gumbel_temperature, gumbel_temp_decay=model_args.gumbel_temperature_decay, ) trainer.train() if __name__ == "__main__": main()