victan committed on
Commit
b84aa12
1 Parent(s): f1fcd6a

Upload seamless_communication/cli/m4t/finetune/dataloader.py with huggingface_hub

Browse files
seamless_communication/cli/m4t/finetune/dataloader.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import json
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import Any, Dict, Iterable, List, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ import torchaudio.compliance.kaldi as ta_kaldi
17
+ from datasets import Dataset
18
+ from datasets.distributed import split_dataset_by_node
19
+ from fairseq2.data.text import TextTokenEncoder
20
+ from fairseq2.models.nllb import NllbTokenizer
21
+ from torch import Tensor
22
+ from torch.nn.functional import pad as pad_tensor
23
+ from torch.utils.data import DataLoader
24
+
25
+ from seamless_communication.datasets.datatypes import LangPairSample
26
+ from seamless_communication.models.unity.unit_tokenizer import (
27
+ UnitTokenEncoder,
28
+ UnitTokenizer,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
@dataclass
class SeqsBatch:
    """Tensors for one sequence-to-sequence task within a batch.

    Every field is optional: a task that has no source side, or no targets
    for the current batch, stores ``None`` in the corresponding fields.
    """

    src_tokens: Optional[Tensor]
    src_lengths: Optional[Tensor]
    target_tokens: Optional[Tensor]
    prev_output_tokens: Optional[Tensor]
    target_lengths: Optional[Tensor]

    def __del__(self) -> None:
        """Drop the references this batch holds to its tensors.

        The intent is to let tensor memory be reclaimed as soon as the batch
        is finalised. Only the references held by this object are released;
        a tensor that is still referenced elsewhere stays alive.
        """
        # `del` on a loop variable would only unbind that local name and
        # leave the attributes untouched, so clear the attributes directly.
        self.src_tokens = None
        self.src_lengths = None
        self.target_tokens = None
        self.prev_output_tokens = None
        self.target_lengths = None
54
+
55
+
56
@dataclass
class MultimodalSeqsBatch:
    """A pair of task batches: speech-to-text and text-to-units."""

    speech_to_text: SeqsBatch
    text_to_units: SeqsBatch

    def __del__(self) -> None:
        """Remove both sub-batch attributes from the instance on finalisation."""
        for field_name in ("speech_to_text", "text_to_units"):
            delattr(self, field_name)
64
+
65
+
66
@dataclass
class BatchingConfig:
    """Settings that control how samples are split, loaded and batched."""

    # The pad index to use in fbanks batching.
    fbank_feats_pad_idx: int = 0

    # Number of samples per batch.
    batch_size: int = 5

    # The rank of this worker in the process group.
    rank: int = 0

    # The world size of the process group.
    world_size: int = 1

    # Parallelism in dataset preparation.
    num_workers: int = 2

    # Select between fp16/fp32 for float tensors.
    float_dtype: torch.dtype = torch.float16
84
+
85
+
86
def worker_init_fn(worker_id: int) -> None:
    """Re-seed NumPy's global RNG so each loader worker gets its own seed.

    The new seed is the first word of the current global RNG state plus
    ``worker_id``.

    Args:
        worker_id: Index of the worker being initialised.
    """
    base_seed = int(np.random.get_state()[1][0])
    # The state word can be as large as 2**32 - 1, and np.random.seed only
    # accepts values in [0, 2**32 - 1], so wrap the sum explicitly instead of
    # relying on NumPy's integer promotion/overflow behaviour.
    np.random.seed((base_seed + worker_id) % 2**32)
88
+
89
+
90
class UnitYDataLoader:
    """Turns a JSON-lines manifest into batches of ``MultimodalSeqsBatch``.

    Each manifest line is parsed into a ``LangPairSample``. For every batch
    the loader computes filterbank features from the source audio, tokenizes
    the target text and, when present, the target units.
    """

    def __init__(
        self,
        text_tokenizer: NllbTokenizer,
        unit_tokenizer: UnitTokenizer,
        dataset_manifest_path: str,
        batching_config: BatchingConfig,
    ):
        """Store the tokenizers and config, then load the manifest.

        Args:
            text_tokenizer: Tokenizer used to build per-language target text
                encoders and to look up the text EOS/pad indices.
            unit_tokenizer: Tokenizer used to build per-language unit
                encoders and to look up the unit EOS/pad indices.
            dataset_manifest_path: Path to a file with one JSON object per
                line.
            batching_config: Batch size, sharding and dtype settings.
        """
        self.text_tokenizer = text_tokenizer
        # Encoders are created lazily, once per target language.
        self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
        self.unit_tokenizer = unit_tokenizer
        self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
        self.batching_config = batching_config
        self.dataset = self._load_manifest(dataset_manifest_path)

    def get_dataloader(self) -> DataLoader:
        """Build a shuffling ``DataLoader`` over this rank's dataset shard."""
        config = self.batching_config
        subset = split_dataset_by_node(
            self.dataset,
            rank=config.rank,
            world_size=config.world_size,
        )
        return DataLoader(
            dataset=subset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            collate_fn=self._prepare_batch,
            worker_init_fn=worker_init_fn,
        )

    def __iter__(self) -> Iterable[MultimodalSeqsBatch]:
        """Iterate over batches from a freshly built ``DataLoader``."""
        return iter(self.get_dataloader())

    def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
        """Load the sample's source audio and return 80-bin fbank features."""
        # NOTE(review): the sample rate returned by torchaudio.load is
        # discarded, so fbank runs with its default sample frequency.
        # Confirm the manifest audio matches that rate.
        audio_input = torchaudio.load(sample.source.audio_local_path)[0]
        return ta_kaldi.fbank(audio_input, num_mel_bins=80)

    def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
        """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
        target_lang = sample.target.lang
        if target_lang not in self.text_encoders_per_lang:
            self.text_encoders_per_lang[
                target_lang
            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
        tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
        # The trailing EOS is appended here, after encoding.
        eos_idx = self.text_tokenizer.vocab_info.eos_idx
        return torch.concat([tokens, torch.LongTensor([eos_idx])])

    def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
        """Expected sequence is [<eos>, <lang_tok> , ..unit tokens.., <eos>]

        Returns ``None`` when the sample carries no target units.
        """
        if sample.target.units is None:
            return None
        target_lang = sample.target.lang
        if target_lang not in self.unit_encoders_per_lang:
            self.unit_encoders_per_lang[
                target_lang
            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
        # The unit encoder is called with a leading batch dimension, which is
        # removed again before the trailing EOS is appended.
        tokens = self.unit_encoders_per_lang[target_lang](
            torch.LongTensor(sample.target.units).unsqueeze(0)
        )
        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
        return torch.concat([tokens.squeeze(0), torch.LongTensor([eos_idx])])

    def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
        """Right-pad tensors along dim 0 to a common length and stack them.

        Args:
            tensors: Non-empty list of tensors that may differ in their
                first dimension.
            pad_value: Constant used to fill the padded positions.

        Returns:
            A tensor with a new leading batch dimension.
        """
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            # torch.nn.functional.pad lists pad widths from the last
            # dimension to the first, so the final entry is the right-hand
            # padding of dim 0.
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack(padded_tensors, dim=0)

    def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
        """Collate raw manifest records into a ``MultimodalSeqsBatch``.

        For text and units, ``prev_output_tokens`` is each sequence without
        its last token, ``target_tokens`` is each sequence without its first
        token, and the lengths are the sequence length minus one. If any
        sample in the batch has no units, all unit fields are ``None``.
        """
        samples = [LangPairSample.from_json(sample) for sample in raw_samples]
        # input speech
        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
        src_tokens = self._batch_tensors(
            src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
        ).to(self.batching_config.float_dtype)
        src_lengths = torch.LongTensor([fbank.shape[0] for fbank in src_tokens_list])
        # output text
        text_tokens_list = [
            self._get_tokenized_target_text(sample) for sample in samples
        ]
        text_pad_idx = self.text_tokenizer.vocab_info.pad_idx
        prev_outputs_tokens = self._batch_tensors(
            [tokens[:-1] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        target_tokens = self._batch_tensors(
            [tokens[1:] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        tokens_lengths = torch.LongTensor(
            [tokens.shape[0] - 1 for tokens in text_tokens_list]
        )
        # output units
        units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
        # Identity check on purpose: `None in list_of_tensors` would fall
        # back to `tensor == None` comparisons.
        if any(units is None for units in units_list_raw):
            prev_outputs_units = None
            target_units = None
            units_lengths = None
        else:
            units_list: List[Tensor] = [
                value for value in units_list_raw if value is not None
            ]
            units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
            prev_outputs_units = self._batch_tensors(
                [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
            )
            target_units = self._batch_tensors(
                [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
            )
            units_lengths = torch.LongTensor(
                [tokens.shape[0] - 1 for tokens in units_list]
            )
        return MultimodalSeqsBatch(
            speech_to_text=SeqsBatch(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                target_tokens=target_tokens,
                prev_output_tokens=prev_outputs_tokens,
                target_lengths=tokens_lengths,
            ),
            text_to_units=SeqsBatch(
                src_tokens=None,
                src_lengths=None,
                target_tokens=target_units,
                prev_output_tokens=prev_outputs_units,
                target_lengths=units_lengths,
            ),
        )

    def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
        """Read a JSON-lines manifest into a ``Dataset``.

        The file is decoded as UTF-8 regardless of the locale, and blank
        lines (such as a trailing empty line) are skipped.
        """
        with open(dataset_manifest_path, encoding="utf-8") as fp_in:
            records = [json.loads(line) for line in fp_in if line.strip()]
        return Dataset.from_list(records)