|
import logging |
|
from dataclasses import dataclass |
|
from typing import Dict, List, Optional, Union, Set |
|
|
|
import torch |
|
import numpy as np |
|
import datasets |
|
from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets |
|
from transformers import AutoFeatureExtractor, AutoTokenizer |
|
from tqdm import tqdm |
|
|
|
from accelerate import Accelerator |
|
|
|
|
|
@dataclass
class DataCollatorEncodecWithPadding:
    """
    Collate raw audio examples into one padded batch via an audio feature extractor.

    Padding is delegated to ``feature_extractor``: the configured ``padding`` strategy and
    ``max_length`` are forwarded to it unchanged (e.g. pad to the longest clip in the batch,
    or to ``max_length`` when ``padding="max_length"``).
    """

    # Callable used to turn the list of raw audio arrays into padded tensors.
    feature_extractor: AutoFeatureExtractor
    # Column whose value is a dict holding the raw waveform under the "array" key.
    audio_column_name: str
    # Not read by __call__ in this class.
    feature_extractor_input_name: Optional[str] = "input_values"
    # Forwarded to the feature extractor as `max_length`.
    max_length: Optional[int] = None
    # Forwarded to the feature extractor as `padding`.
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        Build a batch from ``features``.

        Returns whatever ``feature_extractor`` produces for the batch, plus a ``"len_audio"``
        entry: a ``(batch_size, 1)`` tensor with the length of each waveform before padding.
        """
        waveforms = []
        lengths = []
        for example in features:
            waveform = example[self.audio_column_name]["array"]
            waveforms.append(waveform)
            lengths.append(len(waveform))

        padded = self.feature_extractor(
            waveforms,
            return_tensors="pt",
            padding=self.padding,
            max_length=self.max_length,
        )
        # Keep the unpadded lengths so downstream code can recover the real extent of each clip.
        padded["len_audio"] = torch.tensor(lengths).unsqueeze(1)
        return padded
|
|
|
|
|
@dataclass
class DataCollatorParlerTTSWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
            The prompt_tokenizer used for processing the data.
        description_tokenizer (:class:`~transformers.AutoTokenizer`)
            The description_tokenizer used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the text sequences to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        prompt_max_length (:obj:`int`, `optional`):
            ``max_length`` forwarded to ``prompt_tokenizer.pad``.
        description_max_length (:obj:`int`, `optional`):
            ``max_length`` forwarded to ``description_tokenizer.pad``.
        audio_max_length (:obj:`int`, `optional`):
            When set together with ``padding="max_length"``, labels are padded (never truncated) to this many
            steps along their time axis and an all-ones ``decoder_attention_mask`` is added to the batch.
    """

    prompt_tokenizer: AutoTokenizer
    description_tokenizer: AutoTokenizer
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    prompt_max_length: Optional[int] = None
    description_max_length: Optional[int] = None
    audio_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        Pad labels, description ids and prompt ids of ``features`` into one batch.

        Returns a dict with ``"labels"`` (padded with -100), the outputs of
        ``description_tokenizer.pad`` (``"input_ids"`` and, when produced, ``"attention_mask"``),
        ``"prompt_input_ids"``, ``"prompt_attention_mask"`` when the prompt tokenizer produces one,
        and ``"decoder_attention_mask"`` in fixed-length audio mode.
        """
        pad_audio_to_max_length = self.audio_max_length is not None and self.padding == "max_length"

        # Labels and text inputs need different padding methods, so they are handled separately.
        # Each label tensor is transposed so pad_sequence pads along its (new) first axis;
        # presumably labels are stored as (num_codebooks, seq_len) — TODO confirm against the producer.
        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if pad_audio_to_max_length:
            # Extend dim 1 up to audio_max_length (never truncates). Fill with the same -100 used by
            # pad_sequence above: F.pad otherwise fills with 0, which would make the extra positions
            # look like real label values instead of padding.
            extra_steps = max(self.audio_max_length - labels.shape[1], 0)
            labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, extra_steps), value=-100)

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        if pad_audio_to_max_length:
            # Fixed-shape decoder mask covering every label position, padding included.
            # NOTE(review): padded steps are marked as attended; they are excluded from the targets only
            # through the -100 label value.
            mask_dtype = input_ids["attention_mask"].dtype if "attention_mask" in input_ids else torch.long
            batch["decoder_attention_mask"] = torch.ones(labels.shape[:2], dtype=mask_dtype)

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        return batch
|
|
|
|
|
def convert_dataset_str_to_list(
    dataset_names,
    dataset_config_names,
    metadata_dataset_names=None,
    splits=None,
    dataset_samples=None,
    default_split="train",
):
    """
    Normalise per-dataset specifications into one dict per dataset.

    Every per-dataset argument may be given either as a sized sequence (list, tuple, array) or as a
    single "+"-separated string, e.g. ``"ds_a+ds_b"``. Each argument is normalised independently,
    so strings and lists may be mixed.

    Args:
        dataset_names: Dataset names/paths.
        dataset_config_names: One config name per dataset.
        metadata_dataset_names: Optional metadata dataset name per dataset.
        splits: Optional split per dataset; ``default_split`` is used for every dataset when None.
        dataset_samples: Optional sampling weight per dataset; each entry is converted with ``float``.
        default_split: Split used when ``splits`` is None.

    Returns:
        A list of dicts with keys ``"name"``, ``"config"``, ``"split"``, ``"metadata_dataset_name"``
        and ``"samples"``. Missing optional entries are ``None``.

    Raises:
        ValueError: If a per-dataset argument does not provide exactly one entry per dataset.
    """

    def _split_if_str(value):
        # Only strings are split; lists, tuples, arrays and None pass through unchanged.
        return value.split("+") if isinstance(value, str) else value

    dataset_names = _split_if_str(dataset_names)
    dataset_config_names = _split_if_str(dataset_config_names)
    metadata_dataset_names = _split_if_str(metadata_dataset_names)
    splits = _split_if_str(splits)
    dataset_samples = _split_if_str(dataset_samples)

    num_datasets = len(dataset_names)

    if len(dataset_config_names) != num_datasets:
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != num_datasets:
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if metadata_dataset_names is not None and len(metadata_dataset_names) != num_datasets:
        raise ValueError(
            f"Ensure one metadata dataset is passed for each dataset, got {len(dataset_names)} datasets and {len(metadata_dataset_names)} metadata datasets."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != num_datasets:
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * num_datasets

    if splits is None:
        splits = [default_split] * num_datasets
    # Metadata datasets are optional; without this, indexing None below would raise TypeError.
    if metadata_dataset_names is None:
        metadata_dataset_names = [None] * num_datasets

    return [
        {
            "name": name,
            "config": config,
            "split": split,
            "metadata_dataset_name": metadata_name,
            "samples": samples,
        }
        for name, config, split, metadata_name, samples in zip(
            dataset_names, dataset_config_names, splits, metadata_dataset_names, dataset_samples
        )
    ]
|
|
|
|
|
# Datasets for which the id-based sanity checks in `_merge_metadata_dataset` are skipped.
_DATASETS_WITHOUT_ID_CHECK = {"parler-tts/mls_eng_10k"}


def _merge_metadata_dataset(
    dataset,
    metadata_dataset,
    dataset_name,
    split,
    id_column_name,
    prompt_column_name,
    streaming,
    logger,
):
    """
    Concatenate ``metadata_dataset`` column-wise onto ``dataset`` and return the result.

    Rows are paired purely by position. When ``id_column_name`` is set (and ``dataset_name`` is not
    exempt), the metadata id column is renamed to ``metadata_<id>`` and, for non-streaming datasets,
    every row's ids are checked to match after concatenation.

    Raises:
        ValueError: If the id column is missing from either dataset, or if ids do not match row-wise.
    """
    check_ids = id_column_name is not None and dataset_name not in _DATASETS_WITHOUT_ID_CHECK
    metadata_id_column_name = f"metadata_{id_column_name}"

    if check_ids:
        if id_column_name not in dataset.column_names:
            raise ValueError(
                f"id_column_name={id_column_name} but has not been found in the dataset columns"
                f"- one of {', '.join(list(dataset.column_names))}."
            )
        if id_column_name not in metadata_dataset.column_names:
            raise ValueError(
                f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
                f"- one of {', '.join(list(metadata_dataset.column_names))}."
            )
        metadata_dataset = metadata_dataset.rename_column(id_column_name, metadata_id_column_name)

    # Prefer the metadata version of the prompt column. Only drop the original when the metadata dataset
    # actually provides a replacement, otherwise the prompt would be lost entirely.
    if (
        prompt_column_name is not None
        and prompt_column_name in dataset.column_names
        and prompt_column_name in metadata_dataset.column_names
    ):
        logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_name} - {split}")
        # remove_columns returns a new dataset; the result must be kept.
        dataset = dataset.remove_columns(prompt_column_name)

    # Any remaining column present in both datasets is kept from `dataset` and dropped from the metadata.
    duplicated_columns = set(metadata_dataset.column_names).intersection(dataset.column_names)
    metadata_dataset = metadata_dataset.remove_columns(sorted(duplicated_columns))

    dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

    if check_ids:
        if streaming:
            # An IterableDataset has no len(), so the eager row-wise check cannot run here.
            logger.warning(f"Skipping id consistency check for streaming dataset {dataset_name} - {split}")
        else:
            mismatched_rows = dataset.filter(
                lambda id1, id2: id1 != id2,
                input_columns=[id_column_name, metadata_id_column_name],
            )
            if len(mismatched_rows) != 0:
                raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_name}")

    return dataset


def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    dataset_config_names: Union[List, str],
    metadata_dataset_names: Optional[str] = None,
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
    dataset_samples: Optional[Union[List, np.array]] = None,
    streaming: Optional[bool] = False,
    seed: Optional[int] = None,
    id_column_name: Optional[str] = None,
    columns_to_keep: Optional[Set[str]] = None,
    prompt_column_name: Optional[str] = None,
    sampling_rate: Optional[int] = None,
    audio_column_name: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    """
    Load one or more datasets, optionally merge metadata into each, and combine them.

    Args:
        accelerator: Used so the main process loads (and caches) datasets first.
        dataset_names, dataset_config_names, metadata_dataset_names, splits, dataset_samples:
            Per-dataset specs, as lists or "+"-separated strings (see ``convert_dataset_str_to_list``).
        label_column_names: Accepted for backward compatibility; not used.
        stopping_strategy, seed: Forwarded to ``interleave_datasets`` when streaming.
        streaming: Load datasets in streaming mode and interleave instead of concatenating.
        id_column_name: Column used to check that dataset and metadata rows line up.
        columns_to_keep: If given, every other column is removed from each dataset.
        prompt_column_name: Column whose metadata version replaces the original one.
        sampling_rate, audio_column_name: When both are set, the audio column is cast to that rate.
        logger: Logger for progress messages; a module logger is used when None.
        **kwargs: Forwarded to every ``load_dataset`` call.

    Returns:
        The single dataset when only one is requested. Otherwise, when streaming, the datasets are
        interleaved, using the normalised ``dataset_samples`` as probabilities. When not streaming,
        they are concatenated row-wise and ``dataset_samples`` is not applied.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Pass by keyword: forwarding positionally after `splits` would put `label_column_names` into
    # `dataset_samples` and `dataset_samples` into `default_split`.
    dataset_names_dict = convert_dataset_str_to_list(
        dataset_names,
        dataset_config_names,
        metadata_dataset_names=metadata_dataset_names,
        splits=splits,
        dataset_samples=dataset_samples,
    )

    if dataset_samples is not None:
        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    else:
        probabilities = None

    all_datasets = []
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
        with accelerator.main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                dataset_dict["config"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )

            if sampling_rate is not None and audio_column_name is not None:
                dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))

            metadata_dataset_name = dataset_dict["metadata_dataset_name"]
            if metadata_dataset_name is not None:
                logger.info(
                    f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
                )
                metadata_dataset = load_dataset(
                    metadata_dataset_name,
                    dataset_dict["config"],
                    split=dataset_dict["split"],
                    streaming=streaming,
                    **kwargs,
                )
                dataset = _merge_metadata_dataset(
                    dataset,
                    metadata_dataset,
                    dataset_name=dataset_dict["name"],
                    split=dataset_dict["split"],
                    id_column_name=id_column_name,
                    prompt_column_name=prompt_column_name,
                    streaming=streaming,
                    logger=logger,
                )

            dataset_features = dataset.features.keys()

        if columns_to_keep is not None:
            dataset = dataset.remove_columns(sorted(dataset_features - columns_to_keep))
        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        # A single dataset is returned as-is; nothing to combine.
        return all_datasets[0]

    if streaming:
        interleaved_dataset = interleave_datasets(
            all_datasets,
            stopping_strategy=stopping_strategy,
            probabilities=probabilities,
            seed=seed,
        )
    else:
        with accelerator.main_process_first():
            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
|
|