import logging from dataclasses import dataclass from typing import Dict, List, Optional, Union, Set import torch import numpy as np import datasets from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets from transformers import AutoFeatureExtractor, AutoTokenizer from tqdm import tqdm from accelerate import Accelerator @dataclass class DataCollatorEncodecWithPadding: """ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or to `max_length` if `max_length` is set and `padding=max_length`. """ feature_extractor: AutoFeatureExtractor audio_column_name: str feature_extractor_input_name: Optional[str] = "input_values" max_length: Optional[int] = None padding: Optional[str] = "longest" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need # different padding methods audios = [feature[self.audio_column_name]["array"] for feature in features] len_audio = [len(audio) for audio in audios] # since resampling has already been performed in the 'load_multiple_datasets' function, # a fixed sampling_rate(44100hz) is passed to the feature_extractor. sampling_rate = self.feature_extractor.sampling_rate batch = self.feature_extractor( audios, sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length ) batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1) return batch @dataclass class DataCollatorParlerTTSWithPadding: """ Data collator that will dynamically pad the inputs received. Args: prompt_tokenizer (:class:`~transformers.AutoTokenizer`) The prompt_tokenizer used for proccessing the data. description_tokenizer (:class:`~transformers.AutoTokenizer`) The description_tokenizer used for proccessing the data. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). """ prompt_tokenizer: AutoTokenizer description_tokenizer: AutoTokenizer padding: Union[bool, str] = "longest" pad_to_multiple_of: Optional[int] = None prompt_max_length: Optional[int] = None description_max_length: Optional[int] = None audio_max_length: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need # different padding methods labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features] # (bsz, seq_len, num_codebooks) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) if self.audio_max_length is not None and self.padding == "max_length": labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0))) input_ids = [{"input_ids": feature["input_ids"]} for feature in features] input_ids = self.description_tokenizer.pad( input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.description_max_length, ) batch = {"labels": labels, **input_ids} if self.audio_max_length is not None and self.padding == "max_length": # if we do torch.compile, we need to also specify the attention_mask decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype) batch["decoder_attention_mask"] = decoder_attention_mask prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features] prompt_input_ids = self.prompt_tokenizer.pad( prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.prompt_max_length, ) batch["prompt_input_ids"] = prompt_input_ids["input_ids"] if "attention_mask" in prompt_input_ids: batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"] return batch def convert_dataset_str_to_list( dataset_names, dataset_config_names, metadata_dataset_names=None, splits=None, dataset_samples=None, default_split="train", ): if isinstance(dataset_names, str): dataset_names = dataset_names.split("+") dataset_config_names = dataset_config_names.split("+") splits = splits.split("+") if splits is not None else None dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None metadata_dataset_names = metadata_dataset_names.split("+") if metadata_dataset_names is not None else None # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs if len(dataset_names) != len(dataset_config_names): raise ValueError( f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and" f" {len(dataset_config_names)} configs." ) if splits is not None and len(splits) != len(dataset_names): raise ValueError( f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits." ) if metadata_dataset_names is not None and len(metadata_dataset_names) != len(dataset_names): raise ValueError( f"Ensure one metadata dataset is passed for each dataset, got {len(dataset_names)} datasets and {len(metadata_dataset_names)} metadata datasets." ) if dataset_samples is not None: if len(dataset_samples) != len(dataset_names): raise ValueError( f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and " f"{len(dataset_samples)} samples." ) dataset_samples = [float(ds_sample) for ds_sample in dataset_samples] else: dataset_samples = [None] * len(dataset_names) splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))] dataset_names_dict = [] for i, ds_name in enumerate(dataset_names): dataset_names_dict.append( { "name": ds_name, "config": dataset_config_names[i], "split": splits[i], "metadata_dataset_name": metadata_dataset_names[i], "samples": dataset_samples[i], } ) return dataset_names_dict def load_multiple_datasets( accelerator: Accelerator, dataset_names: Union[List, str], dataset_config_names: Union[List, str], metadata_dataset_names: Optional[str] = None, splits: Optional[Union[List, str]] = None, label_column_names: Optional[List] = None, stopping_strategy: Optional[str] = "first_exhausted", dataset_samples: Optional[Union[List, np.array]] = None, streaming: Optional[bool] = False, seed: Optional[int] = None, id_column_name: Optional[str] = None, columns_to_keep: Optional[Set[str]] = None, prompt_column_name: Optional[str] = None, sampling_rate: Optional[int] = None, audio_column_name: Optional[str] = None, logger: Optional[logging.Logger] = None, **kwargs, ) -> Union[Dataset, IterableDataset]: dataset_names_dict = convert_dataset_str_to_list( dataset_names, dataset_config_names, metadata_dataset_names, splits, label_column_names, dataset_samples ) if dataset_samples is not None: dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict] probabilities = np.array(dataset_samples) / np.sum(dataset_samples) else: probabilities = None all_datasets = [] # iterate over the datasets we want to interleave for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."): with accelerator.main_process_first(): dataset = load_dataset( dataset_dict["name"], dataset_dict["config"], split=dataset_dict["split"], streaming=streaming, **kwargs, ) dataset_features = dataset.features.keys() if sampling_rate is not None and audio_column_name is not None: # resample target audio dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)) metadata_dataset_name = dataset_dict["metadata_dataset_name"] if metadata_dataset_name is not None: logger.info( f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}' ) metadata_dataset = load_dataset( metadata_dataset_name, dataset_dict["config"], split=dataset_dict["split"], streaming=streaming, **kwargs, ) # TODO(YL): I forgot to create unique ids for MLS english. # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time # if dataset_dict["name"] == "parler-tts/mls_eng_10k": # def concat_ids(book_id, speaker_id, begin_time): # return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"} # dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24) # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24) # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}") if dataset_dict["name"] != "parler-tts/mls_eng_10k": if id_column_name is not None and id_column_name not in dataset.column_names: raise ValueError( f"id_column_name={id_column_name} but has not been found in the dataset columns" f"- one of {', '.join(list(dataset.column_names))}." ) if id_column_name is not None and id_column_name not in metadata_dataset.column_names: raise ValueError( f"id_column_name={id_column_name} but has not been found in the metadata dataset columns" f"- one of {', '.join(list(metadata_dataset.column_names))}." ) elif id_column_name is not None: metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}") metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names)) if prompt_column_name is not None: # We might have applied some transformations to the prompts (e.g punctuation restoration) # so we make sure to remove it from the original dataset if prompt_column_name in dataset.column_names: logger.info( f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']" ) dataset.remove_columns(prompt_column_name) metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names)) metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove) dataset = concatenate_datasets([dataset, metadata_dataset], axis=1) if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k": if ( len( dataset.filter( lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"], ) ) != 0 ): raise ValueError( f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}" ) dataset_features = dataset.features.keys() if columns_to_keep is not None: dataset = dataset.remove_columns(set(dataset_features - columns_to_keep)) all_datasets.append(dataset) if len(all_datasets) == 1: # we have a single dataset so just return it as is return all_datasets[0] if streaming: interleaved_dataset = interleave_datasets( all_datasets, stopping_strategy=stopping_strategy, probabilities=probabilities, seed=seed, ) else: with accelerator.main_process_first(): interleaved_dataset = concatenate_datasets(all_datasets) return interleaved_dataset