# File size: 8,386 Bytes
# b84aa12
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from fairseq2.data.text import TextTokenEncoder
from fairseq2.models.nllb import NllbTokenizer
from torch import Tensor
from torch.nn.functional import pad as pad_tensor
from torch.utils.data import DataLoader
from seamless_communication.datasets.datatypes import LangPairSample
from seamless_communication.models.unity.unit_tokenizer import (
UnitTokenEncoder,
UnitTokenizer,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class SeqsBatch:
    """Tensors describing one batch of sequences.

    Every field may be ``None``. As filled in by
    ``UnitYDataLoader._prepare_batch``:

    * ``src_tokens`` / ``src_lengths`` - padded source features and the
      unpadded first-dimension length of each source item;
    * ``prev_output_tokens`` - each target sequence without its last token,
      padded;
    * ``target_tokens`` - each target sequence without its first token,
      padded;
    * ``target_lengths`` - length of each target sequence minus one.
    """

    src_tokens: Optional[Tensor]
    src_lengths: Optional[Tensor]
    target_tokens: Optional[Tensor]
    prev_output_tokens: Optional[Tensor]
    target_lengths: Optional[Tensor]

    def __del__(self) -> None:
        """Drop this object's references to its tensors.

        The attributes themselves are reset to ``None``. Deleting a loop
        variable that merely aliases an attribute would only unbind that local
        name and release nothing. Plain assignment also keeps the finalizer
        safe to run more than once and on a partially constructed instance.
        Tensor memory can only be reclaimed once no other reference to the
        tensor remains.
        """
        for name in (
            "src_tokens",
            "src_lengths",
            "target_tokens",
            "prev_output_tokens",
            "target_lengths",
        ):
            setattr(self, name, None)
@dataclass
class MultimodalSeqsBatch:
    """The two sub-batches built from one group of samples.

    ``speech_to_text`` carries the source features together with the text
    targets; ``text_to_units`` carries the unit targets.
    """

    speech_to_text: SeqsBatch
    text_to_units: SeqsBatch

    def __del__(self) -> None:
        """Remove both sub-batch attributes from this instance.

        The attributes are removed tolerantly: a finalizer can run on an
        instance whose ``__init__`` never completed, or be invoked a second
        time, and a plain ``del self.attr`` would then raise
        ``AttributeError`` (surfaced only as an "Exception ignored" message).
        """
        instance_attrs = vars(self)
        for name in ("speech_to_text", "text_to_units"):
            instance_attrs.pop(name, None)
@dataclass
class BatchingConfig:
    """Options that control how samples are turned into batches."""

    # The pad index to use in fbanks batching.
    fbank_feats_pad_idx: int = 0

    # Number of samples per batch (handed to the DataLoader as batch_size).
    batch_size: int = 5

    # The rank of this worker in the process group.
    rank: int = 0

    # The world size of the process group.
    world_size: int = 1

    # Parallelism in dataset preparation.
    num_workers: int = 2

    # Select between fp16/fp32 for float tensors.
    float_dtype: torch.dtype = torch.float16
def worker_init_fn(worker_id: int) -> None:
    """Re-seed NumPy's global RNG so each DataLoader worker gets its own seed.

    The new seed is the first element of the current global state's key array
    plus ``worker_id``.

    Args:
        worker_id: Index of the worker, as passed by ``DataLoader``.
    """
    # Convert to a plain int before adding: the key element is a 32-bit NumPy
    # scalar, and the unchecked sum can leave the accepted seed range (or
    # overflow the scalar, depending on the NumPy version).
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1], so wrap around.
    np.random.seed((base_seed + worker_id) % 2**32)
class UnitYDataLoader:
    """Serve padded speech-to-text / text-to-unit batches from a manifest.

    The manifest is a text file holding one JSON object per line. Inside the
    collate function each record is turned into a ``LangPairSample``; the
    source audio is converted to fbank features, the target text is tokenized
    and, when the sample carries them, so are the target units.
    """

    # ``ta_kaldi.fbank`` is run with a fixed 16 kHz ``sample_frequency`` (its
    # documented default), so audio at any other rate is resampled first.
    _FBANK_SAMPLE_RATE = 16000

    def __init__(
        self,
        text_tokenizer: NllbTokenizer,
        unit_tokenizer: UnitTokenizer,
        dataset_manifest_path: str,
        batching_config: BatchingConfig,
    ) -> None:
        """Store tokenizers and configuration, then load the manifest.

        Args:
            text_tokenizer: Tokenizer used for the target text.
            unit_tokenizer: Tokenizer used for the target units.
            dataset_manifest_path: Path to the JSON-lines manifest.
            batching_config: Batching / sharding options.
        """
        self.text_tokenizer = text_tokenizer
        # Encoders are created lazily, once per target language.
        self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
        self.unit_tokenizer = unit_tokenizer
        self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
        self.batching_config = batching_config
        self.dataset = self._load_manifest(dataset_manifest_path)

    def get_dataloader(self) -> DataLoader:
        """Create a shuffling ``DataLoader`` over this rank's part of the data.

        The dataset is first split by ``rank`` / ``world_size`` from the
        batching config; batches are collated by ``_prepare_batch``.
        """
        config = self.batching_config
        subset = split_dataset_by_node(
            self.dataset,
            rank=config.rank,
            world_size=config.world_size,
        )
        return DataLoader(
            dataset=subset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            collate_fn=self._prepare_batch,
            worker_init_fn=worker_init_fn,
        )

    def __iter__(self) -> Iterable[MultimodalSeqsBatch]:
        """Iterate over batches; a fresh ``DataLoader`` is built on each call."""
        return self.get_dataloader().__iter__()

    def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
        """Load the sample's source audio and compute 80-bin fbank features.

        Audio whose sample rate differs from ``_FBANK_SAMPLE_RATE`` is
        resampled first, so the features are always computed at the rate that
        ``fbank`` is told to assume.
        """
        waveform, sample_rate = torchaudio.load(sample.source.audio_local_path)
        sample_rate = int(sample_rate)
        if sample_rate != self._FBANK_SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, self._FBANK_SAMPLE_RATE
            )
        return ta_kaldi.fbank(
            waveform,
            num_mel_bins=80,
            sample_frequency=float(self._FBANK_SAMPLE_RATE),
        )

    def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
        """Tokenize the target text and append a trailing ``<eos>``.

        Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]
        """
        target_lang = sample.target.lang
        encoder = self.text_encoders_per_lang.get(target_lang)
        if encoder is None:
            encoder = self.text_tokenizer.create_encoder(
                lang=target_lang, mode="target"
            )
            self.text_encoders_per_lang[target_lang] = encoder
        tokens = encoder(sample.target.text)
        eos_idx = self.text_tokenizer.vocab_info.eos_idx
        return torch.concat([tokens, torch.LongTensor([eos_idx])])

    def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
        """Tokenize the target units and append a trailing ``<eos>``.

        Expected sequence is [<eos>, <lang_tok> , ..unit tokens.., <eos>]

        Returns:
            The token tensor, or ``None`` when the sample has no units.
        """
        if sample.target.units is None:
            return None
        target_lang = sample.target.lang
        encoder = self.unit_encoders_per_lang.get(target_lang)
        if encoder is None:
            encoder = self.unit_tokenizer.create_encoder(lang=target_lang)
            self.unit_encoders_per_lang[target_lang] = encoder
        # The encoder is fed a tensor with an extra leading dimension, which
        # is squeezed away again before the eos token is appended.
        tokens = encoder(torch.LongTensor(sample.target.units).unsqueeze(0))
        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
        return torch.concat([tokens.squeeze(0), torch.LongTensor([eos_idx])])

    def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
        """Right-pad tensors along dim 0 to a common length and stack them.

        Args:
            tensors: Tensors that agree in every dimension except the first.
            pad_value: Fill value for the padded positions.

        Returns:
            A tensor with a new leading batch dimension.

        Raises:
            ValueError: If ``tensors`` is empty.
        """
        if not tensors:
            raise ValueError("Cannot batch an empty list of tensors")
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            # ``pad`` takes (left, right) pairs starting from the LAST
            # dimension, so the final entry is the right-hand padding of dim 0.
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack(padded_tensors, dim=0)

    def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
        """Collate raw manifest records into a ``MultimodalSeqsBatch``.

        Targets are split teacher-forcing style: ``prev_output_tokens`` drops
        the last token of each sequence and ``target_tokens`` drops the first,
        so the reported lengths are the sequence length minus one. If any
        sample in the batch lacks units, all unit tensors of the batch are
        ``None``.
        """
        samples = [LangPairSample.from_json(raw) for raw in raw_samples]

        # input speech
        fbanks = [self._get_source_fbank(sample) for sample in samples]
        src_tokens = self._batch_tensors(
            fbanks, pad_value=self.batching_config.fbank_feats_pad_idx
        ).to(self.batching_config.float_dtype)
        src_lengths = torch.LongTensor([fbank.shape[0] for fbank in fbanks])

        # output text
        text_tokens_list = [
            self._get_tokenized_target_text(sample) for sample in samples
        ]
        text_pad_idx = self.text_tokenizer.vocab_info.pad_idx
        prev_outputs_tokens = self._batch_tensors(
            [tokens[:-1] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        target_tokens = self._batch_tensors(
            [tokens[1:] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        tokens_lengths = torch.LongTensor(
            [tokens.shape[0] - 1 for tokens in text_tokens_list]
        )

        # output units
        units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
        prev_outputs_units: Optional[Tensor] = None
        target_units: Optional[Tensor] = None
        units_lengths: Optional[Tensor] = None
        # Identity check on purpose: ``None in list_of_tensors`` would fall
        # back to ``Tensor == None`` comparisons for every element.
        if not any(units is None for units in units_list_raw):
            units_list: List[Tensor] = [
                units for units in units_list_raw if units is not None
            ]
            units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
            prev_outputs_units = self._batch_tensors(
                [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
            )
            target_units = self._batch_tensors(
                [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
            )
            units_lengths = torch.LongTensor(
                [tokens.shape[0] - 1 for tokens in units_list]
            )

        return MultimodalSeqsBatch(
            speech_to_text=SeqsBatch(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                target_tokens=target_tokens,
                prev_output_tokens=prev_outputs_tokens,
                target_lengths=tokens_lengths,
            ),
            text_to_units=SeqsBatch(
                src_tokens=None,
                src_lengths=None,
                target_tokens=target_units,
                prev_output_tokens=prev_outputs_units,
                target_lengths=units_lengths,
            ),
        )

    def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
        """Read a JSON-lines manifest into a ``Dataset``.

        The file is decoded as UTF-8 regardless of the platform's default
        encoding, and blank lines are skipped instead of failing JSON parsing.
        """
        with open(dataset_manifest_path, encoding="utf-8") as fp_in:
            records = [json.loads(line) for line in fp_in if line.strip()]
        return Dataset.from_list(records)
|