victan committed
Commit a993bec
1 Parent(s): 8192381

Upload seamless_communication/cli/m4t/finetune/trainer.py with huggingface_hub
seamless_communication/cli/m4t/finetune/trainer.py ADDED
@@ -0,0 +1,374 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.


import logging
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.data import VocabularyInfo
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.nn.padding import PaddingMask
from fairseq2.optim.lr_scheduler import MyleLR
from fairseq2.typing import Device
from torch.optim import Adam

from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
from seamless_communication.models.unity import UnitYModel

logger = logging.getLogger(__name__)


class FinetuneMode(Enum):
    SPEECH_TO_SPEECH = "SPEECH_TO_SPEECH"
    SPEECH_TO_TEXT = "SPEECH_TO_TEXT"
    TEXT_TO_SPEECH = "TEXT_TO_SPEECH"


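# `finetune_mode` selects which sub-model stays frozen: TEXT_TO_SPEECH runs the S2T part
# (speech encoder and text decoder) under torch.no_grad(), SPEECH_TO_TEXT does the same
# for the T2U sub-model, and SPEECH_TO_SPEECH trains both (see UnitYFinetuneWrapper below).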
@dataclass
class FinetuneParams:
    save_model_path: Path
    """Path where to save the finetuned model."""

    finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
    """Allows freezing the S2T or T2U part of the model."""

    max_epochs: int = 10
    """Maximum number of training epochs."""

    label_smoothing: float = 0.2
    """Label smoothing coefficient for nll_loss."""

    warmup_steps: int = 100
    """Number of steps with linearly increasing LR."""

    log_steps: int = 10
    """Log inner loss after each `log_steps` training steps."""

    eval_steps: int = 50
    """Get eval loss after each `eval_steps` training steps."""

    patience: int = 3
    """Terminate if eval loss did not improve
    over the last `patience * eval_steps` training steps."""

    learning_rate: float = 1e-5
    """Optimizer learning rate."""

    train_batch_size: int = 5
    """The batch size during train steps."""

    eval_batch_size: int = 5
    """The batch size during evaluation."""

    device: Device = torch.device("cuda")
    """Where to run computation."""


class UnitYFinetuneWrapper(nn.Module):
    """Convenience wrapper that does a forward pass
    and returns S2T and T2U logits"""

    def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
        super().__init__()
        assert model.t2u_model is not None
        self.model: UnitYModel = model
        self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
        self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
        self.device = device

    def forward(
        self, batch: dataloader.MultimodalSeqsBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert self.model.t2u_model is not None
        dummy_context = contextmanager(lambda: iter([None]))()
        with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
            assert batch.speech_to_text.src_tokens is not None
            seqs = batch.speech_to_text.src_tokens.to(self.device)
            seq_lens = batch.speech_to_text.src_lengths.to(self.device)
            speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
            )
            assert batch.speech_to_text.prev_output_tokens is not None
            seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
            seq_lens = batch.speech_to_text.target_lengths.to(self.device)
            text_decoder_out, text_decoder_padding_mask = self.model.decode(
                seqs=seqs,
                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                encoder_output=speech_encoder_out,
                encoder_padding_mask=speech_encoder_padding_mask,
            )
            text_logits = self.model.final_proj(text_decoder_out)
        if batch.text_to_units.prev_output_tokens is None:
            return (text_logits, None)
        dummy_context = contextmanager(lambda: iter([None]))()
        with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
            (
                unit_encoder_out,
                unit_encoder_padding_mask,
            ) = self.model.t2u_model.encode(
                text_decoder_output=text_decoder_out,
                text_decoder_padding_mask=text_decoder_padding_mask,
            )
            seqs = batch.text_to_units.prev_output_tokens.to(self.device)
            seq_lens = batch.text_to_units.target_lengths.to(self.device)
            unit_decoder_out, _ = self.model.t2u_model.decode(
                seqs=seqs,
                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                encoder_output=unit_encoder_out,
                encoder_padding_mask=unit_encoder_padding_mask,
            )
            unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)

        return (text_logits, unit_logits)


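# Combines the S2T and T2U objectives: each nll term is normalized by its own number of
# target tokens, and the T2U term is added only when unit targets are present in the batch.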
class CalcLoss:
    """Calculates negative log likelihood loss for S2T and T2U"""

    def __init__(
        self,
        label_smoothing: float,
        s2t_vocab_info: VocabularyInfo,
        t2u_vocab_info: VocabularyInfo,
    ):
        self.label_smoothing = label_smoothing
        self.s2t_vocab_info = s2t_vocab_info
        self.t2u_vocab_info = t2u_vocab_info

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: Optional[torch.Tensor],
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
            text_logits.device
        )
        s2t_loss = SequenceModelOutput(
            logits=text_logits, vocab_info=self.s2t_vocab_info
        ).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=1,
            label_smoothing=self.label_smoothing,
        )
        if unit_logits is None:
            return s2t_loss / s2t_numel
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
        s2u_loss = SequenceModelOutput(
            logits=unit_logits, vocab_info=self.t2u_vocab_info
        ).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=1,
            label_smoothing=self.label_smoothing,
        )
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel


class LossCollector:
    """Aggregates loss history across nodes"""

    def __init__(self, device: Optional[Device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            return val_sum / (n_samples + 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError()

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [
            torch.zeros((1, 2), device=self.device)
            for _ in range(dist_utils.get_world_size())
        ]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()


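# Training driver: iterates over the train dataloader for up to `max_epochs`, evaluates
# every `eval_steps` updates, saves the best checkpoint to `save_model_path`, and stops
# early once eval loss has not improved for `patience` consecutive evaluations.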
class UnitYFinetune:
    def __init__(
        self,
        model: UnitYModel,
        params: FinetuneParams,
        train_data_loader: dataloader.UnitYDataLoader,
        eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
    ):
        self.params = params

        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_vocab_info=model.target_vocab_info,
            t2u_vocab_info=model.t2u_model.target_vocab_info,
        )
        self.model = self._wrap_model_for_training(model=model)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        self.grad_scaler = torch.cuda.amp.GradScaler()
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=1e-9,
        )

        self.train_loss_hist = LossCollector(device=params.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.best_eval_loss = None
        self.is_best_state = False

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYFinetuneWrapper(
            model=model, mode=self.params.finetune_mode, device=self.params.device
        )
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.is_best_state = (
            self.best_eval_loss is None or eval_loss < self.best_eval_loss
        )
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = (
            self.params.patience if self.is_best_state else self.patience_left - 1
        )
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Calc avg loss on eval dataset and update evaluation stats"""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.params.device)
        self.model.eval()
        with torch.no_grad():
            for batch in self.eval_data_loader.get_dataloader():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("Inf")
                else:
                    loss_val = loss.item()
                del batch  # force memory release
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self):
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
            )

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self._train_step_log()

    def _save_model(self):
        logger.info("Saving model")
        if dist_utils.is_main_process():
            state_dict = {
                key.replace("module.model.", ""): value
                for key, value in self.model.state_dict().items()
            }
            torch.save(state_dict, self.params.save_model_path)
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self):
        logger.info("Start finetuning")
        self._reset_stats()
        self._eval_model()
        batch_itr = self.train_data_loader.get_dataloader()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in batch_itr:
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            "Early termination, as eval loss did not improve "
                            f"over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.epoch_idx += 1
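
For context, the sketch below shows how the classes in this file are meant to be wired together. It is illustrative only: model loading and dataloader construction live outside this diff, so `load_model_and_dataloaders` is a hypothetical placeholder and the parameter values are arbitrary examples.

# Hypothetical usage sketch; only FinetuneParams, FinetuneMode and UnitYFinetune come
# from the file above. `load_model_and_dataloaders` is a placeholder, not a real API.
from pathlib import Path

import torch

from seamless_communication.cli.m4t.finetune.trainer import (
    FinetuneMode,
    FinetuneParams,
    UnitYFinetune,
)

# Placeholder: returns a UnitYModel (with a non-None t2u_model) and two
# dataloader.UnitYDataLoader instances; their construction is not part of this file.
model, train_loader, eval_loader = load_model_and_dataloaders()

params = FinetuneParams(
    save_model_path=Path("checkpoint.pt"),
    finetune_mode=FinetuneMode.SPEECH_TO_TEXT,  # keep the T2U sub-model frozen
    max_epochs=5,
    warmup_steps=100,
    eval_steps=50,
    patience=3,
    learning_rate=1e-5,
    device=torch.device("cuda"),
)

UnitYFinetune(
    model=model,
    params=params,
    train_data_loader=train_loader,
    eval_data_loader=eval_loader,
).run()

With this configuration the trainer logs training loss every 10 updates (the `log_steps` default), evaluates every 50 updates, and writes the best-scoring checkpoint to checkpoint.pt.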