Spaces:
Runtime error
Runtime error
import torch
from typing import Any, Dict, List, Union
class TTSDataCollatorWithPadding:
    """Collate TTS examples into a padded batch for SpeechT5-style fine-tuning.

    Pads token inputs and spectrogram targets via the processor, masks label
    padding with -100 so the loss ignores it, trims targets to a multiple of
    the model's reduction factor, and attaches speaker embeddings.
    """

    def __init__(self, model, processor):
        # model: provides config.reduction_factor (decoder frame-reduction rate)
        # processor: provides .pad() for joint input/label padding
        self.model = model
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a padded batch from a list of per-example feature dicts.

        Each feature dict is expected to carry "input_ids" (token ids),
        "labels" (target spectrogram frames), and "speaker_embeddings".
        Returns a batch dict with "input_ids", "labels" (padding → -100),
        and "speaker_embeddings"; "decoder_attention_mask" is dropped.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Collate the inputs and targets into a batch.
        batch = self.processor.pad(
            input_ids=input_ids,
            labels=label_features,
            return_tensors="pt",
        )

        # Replace padded label positions with -100 so the loss ignores them.
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )

        # The mask is not used during fine-tuning.
        del batch["decoder_attention_mask"]

        # Round target lengths down to a multiple of the reduction factor,
        # then trim the batch to the longest rounded length. Vectorized
        # arithmetic replaces the deprecated Tensor.new() + Python loop.
        reduction_factor = self.model.config.reduction_factor
        if reduction_factor > 1:
            target_lengths = torch.tensor(
                [len(feature["input_values"]) for feature in label_features]
            )
            target_lengths = target_lengths - target_lengths % reduction_factor
            batch["labels"] = batch["labels"][:, : int(target_lengths.max())]

        # Stack per-example embeddings; avoids the slow torch.tensor()
        # path over a list of arrays.
        batch["speaker_embeddings"] = torch.stack(
            [torch.as_tensor(emb) for emb in speaker_features]
        )

        return batch