|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.data import VocabularyInfo
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.nn.padding import PaddingMask
from fairseq2.optim.lr_scheduler import MyleLR
from fairseq2.typing import Device
from torch.optim import Adam

from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
from seamless_communication.models.unity import UnitYModel
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class FinetuneMode(Enum):
    """Finetuning modes; each member's value equals its own name.

    `UnitYFinetuneWrapper` derives its freeze flags from the mode:
    `TEXT_TO_SPEECH` freezes the S2T part, `SPEECH_TO_TEXT` freezes the
    T2U part, and `SPEECH_TO_SPEECH` freezes neither.
    """

    # Neither the S2T nor the T2U part is frozen.
    SPEECH_TO_SPEECH = "SPEECH_TO_SPEECH"

    # The T2U part runs without gradients.
    SPEECH_TO_TEXT = "SPEECH_TO_TEXT"

    # The S2T part runs without gradients.
    TEXT_TO_SPEECH = "TEXT_TO_SPEECH"
|
|
|
|
|
@dataclass
class FinetuneParams:
    """Hyper-parameters and settings consumed by `UnitYFinetune`."""

    save_model_path: Path
    """Path where the finetuned model is saved."""

    finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
    """Selects which part of the model (S2T or T2U) is frozen."""

    max_epochs: int = 10
    """Maximum number of training epochs."""

    label_smoothing: float = 0.2
    """Label smoothing coefficient for the NLL loss."""

    warmup_steps: int = 100
    """Number of steps with linearly increasing LR."""

    log_steps: int = 10
    """Log the train loss after every `log_steps` training steps."""

    eval_steps: int = 50
    """Compute the eval loss after every `eval_steps` training steps."""

    patience: int = 3
    """Terminate if the eval loss did not improve over the last
    `patience * eval_steps` training steps."""

    learning_rate: float = 1e-5
    """Optimizer learning rate."""

    train_batch_size: int = 5
    """Batch size used during training steps."""

    eval_batch_size: int = 5
    """Batch size used during evaluation."""

    device: Device = torch.device("cuda")
    """Device on which the computation runs."""
|
|
|
|
|
class UnitYFinetuneWrapper(nn.Module):
    """Convenience wrapper that does a forward pass
    and returns S2T and T2U logits.

    Depending on the finetune mode, one part of the model is run under
    `torch.no_grad()` so that it receives no gradients:
    `TEXT_TO_SPEECH` freezes the S2T part, `SPEECH_TO_TEXT` freezes the
    T2U part.
    """

    def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
        """Store the wrapped model and derive the freeze flags from `mode`.

        Args:
            model: UnitY model to finetune; it must have a T2U sub-model.
            mode: Finetune mode that decides which part is frozen.
            device: Device the batch tensors are moved to in `forward`.
        """
        super().__init__()
        assert model.t2u_model is not None
        self.model: UnitYModel = model
        self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
        self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
        self.device = device

    @staticmethod
    def _grad_context(frozen: bool):
        """Return `torch.no_grad()` when `frozen`, otherwise a no-op context.

        The no-op branch deliberately leaves the ambient grad mode untouched,
        so an enclosing `torch.no_grad()` (e.g. during evaluation) still
        applies.
        """
        return torch.no_grad() if frozen else nullcontext()

    def forward(
        self, batch: dataloader.MultimodalSeqsBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the S2T pass and, when unit targets are present, the T2U pass.

        Args:
            batch: Batch with `speech_to_text` and `text_to_units` parts.

        Returns:
            `(text_logits, unit_logits)`; `unit_logits` is `None` when
            `batch.text_to_units.prev_output_tokens` is `None`.
        """
        assert self.model.t2u_model is not None

        # Speech encoder + text decoder (the S2T part).
        with self._grad_context(self.freeze_s2t):
            assert batch.speech_to_text.src_tokens is not None
            seqs = batch.speech_to_text.src_tokens.to(self.device)
            seq_lens = batch.speech_to_text.src_lengths.to(self.device)
            speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
            )
            assert batch.speech_to_text.prev_output_tokens is not None
            seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
            seq_lens = batch.speech_to_text.target_lengths.to(self.device)
            text_decoder_out, text_decoder_padding_mask = self.model.decode(
                seqs=seqs,
                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                encoder_output=speech_encoder_out,
                encoder_padding_mask=speech_encoder_padding_mask,
            )
            text_logits = self.model.final_proj(text_decoder_out)

        # Without unit inputs there is nothing to feed the T2U decoder.
        if batch.text_to_units.prev_output_tokens is None:
            return (text_logits, None)

        # T2U encoder + decoder, fed with the text decoder output.
        with self._grad_context(self.freeze_t2u):
            (
                unit_encoder_out,
                unit_encoder_padding_mask,
            ) = self.model.t2u_model.encode(
                text_decoder_output=text_decoder_out,
                text_decoder_padding_mask=text_decoder_padding_mask,
            )
            seqs = batch.text_to_units.prev_output_tokens.to(self.device)
            seq_lens = batch.text_to_units.target_lengths.to(self.device)
            unit_decoder_out, _ = self.model.t2u_model.decode(
                seqs=seqs,
                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
                encoder_output=unit_encoder_out,
                encoder_padding_mask=unit_encoder_padding_mask,
            )
            unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)

        return (text_logits, unit_logits)
|
|
|
|
|
class CalcLoss:
    """Calculates negative log likelihood loss for S2T and T2U"""

    def __init__(
        self,
        label_smoothing: float,
        s2t_vocab_info: VocabularyInfo,
        t2u_vocab_info: VocabularyInfo,
    ):
        """Remember the smoothing coefficient and both vocabularies."""
        self.label_smoothing = label_smoothing
        self.s2t_vocab_info = s2t_vocab_info
        self.t2u_vocab_info = t2u_vocab_info

    def _length_normalized_loss(
        self,
        logits: torch.Tensor,
        target_tokens: torch.Tensor,
        target_lengths: torch.Tensor,
        vocab_info: VocabularyInfo,
    ) -> torch.Tensor:
        """Return the sequence loss divided by the summed target lengths."""
        where = logits.device
        total_targets = torch.sum(target_lengths).to(where)
        model_output = SequenceModelOutput(logits=logits, vocab_info=vocab_info)
        raw_loss = model_output.compute_loss(
            targets=target_tokens.to(where),
            ignore_prefix_size=1,
            label_smoothing=self.label_smoothing,
        )
        return raw_loss / total_targets

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the loss for one batch.

        Returns the length-normalized S2T loss, plus the length-normalized
        T2U loss when `unit_logits` is not `None`.
        """
        text_part = batch.speech_to_text
        assert text_part.target_lengths is not None
        text_loss = self._length_normalized_loss(
            text_logits,
            text_part.target_tokens,
            text_part.target_lengths,
            self.s2t_vocab_info,
        )
        if unit_logits is None:
            return text_loss

        unit_part = batch.text_to_units
        assert unit_part.target_lengths is not None
        unit_loss = self._length_normalized_loss(
            unit_logits,
            unit_part.target_tokens,
            unit_part.target_lengths,
            self.t2u_vocab_info,
        )
        return text_loss + unit_loss
|
|
|
|
|
class LossCollector:
    """Aggregates loss history across nodes.

    Accumulates a sample count and a loss sum locally; `reduce()` combines
    them across all ranks (when distributed) and returns either the mean
    or the sum, depending on `reduce_op`.
    """

    def __init__(self, device: Optional[Device] = None, reduce_op: str = "avg"):
        """Create an empty collector.

        Args:
            device: Device used for the tensors exchanged between ranks.
            reduce_op: Either "avg" (mean loss per sample) or "sum".
        """
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        """Drop all locally accumulated statistics."""
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        """Add `batch_loss` covering `n_samples` samples to the local totals."""
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        """Return the aggregated loss across all ranks.

        Returns:
            For "avg": the loss sum divided by the number of samples, or
            0.0 when no samples were collected. For "sum": the loss sum.

        Raises:
            ValueError: If `reduce_op` is neither "avg" nor "sum".
        """
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            # Divide by the true sample count; the empty case is guarded
            # explicitly instead of biasing every average with a "+ 1".
            return val_sum / n_samples if n_samples else 0.0
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(
            f"Unsupported reduce_op {self.reduce_op!r}; expected 'avg' or 'sum'"
        )

    def _collect(self) -> Tuple[float, float]:
        """Return `(n_samples, val_sum)` summed over all ranks."""
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        # One (1, 2) receive buffer per rank for all_gather.
        all_vals = [
            torch.zeros((1, 2), device=self.device)
            for _ in range(dist_utils.get_world_size())
        ]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
|
|
|
|
|
class UnitYFinetune:
    """Finetuning loop for a UnitY model with periodic evaluation,
    best-checkpoint saving and patience-based early stopping."""

    def __init__(
        self,
        model: UnitYModel,
        params: FinetuneParams,
        train_data_loader: dataloader.UnitYDataLoader,
        eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
    ):
        """Set up the loss, wrapped model, optimizer, scaler and scheduler.

        Args:
            model: UnitY model to finetune; it must have a T2U sub-model.
            params: Finetuning hyper-parameters.
            train_data_loader: Source of training batches.
            eval_data_loader: Optional source of evaluation batches.
        """
        self.params = params

        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_vocab_info=model.target_vocab_info,
            t2u_vocab_info=model.t2u_model.target_vocab_info,
        )
        self.model = self._wrap_model_for_trainining(model=model)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        # The fused Adam kernel is only requested on CUDA, so that a
        # non-CUDA `params.device` does not fail at optimizer construction.
        use_fused = torch.device(self.params.device).type == "cuda"
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=use_fused,
        )
        self.grad_scaler = torch.cuda.amp.GradScaler()
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=1e-9,
        )

        self.train_loss_hist = LossCollector(device=params.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False

    def _reset_stats(self) -> None:
        """Reset loss history, counters and early-stopping state."""
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.best_eval_loss = None
        self.is_best_state = False

    def _wrap_model_for_trainining(self, model: UnitYModel) -> nn.Module:
        """Wrap `model` in `UnitYFinetuneWrapper`, plus DDP when distributed."""
        wrapped_model = UnitYFinetuneWrapper(
            model=model, mode=self.params.finetune_mode, device=self.params.device
        )
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        """Record `eval_loss`, updating the best loss and remaining patience."""
        self.is_best_state = (
            self.best_eval_loss is None or eval_loss < self.best_eval_loss
        )
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        # Patience is restored on improvement and consumed otherwise.
        self.patience_left = (
            self.params.patience if self.is_best_state else self.patience_left - 1
        )
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Calc avg loss on eval dataset and update evaluation stats"""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.params.device)
        self.model.eval()
        with torch.no_grad():
            for batch in self.eval_data_loader.get_dataloader():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("Inf")
                else:
                    loss_val = loss.item()
                # Drop the batch reference before the next iteration.
                del batch
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self) -> None:
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
            )

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self._train_step_log()

    @staticmethod
    def _strip_wrapper_prefix(key: str) -> str:
        """Remove the DDP ("module.") and wrapper ("model.") key prefixes.

        Each prefix is removed at most once and only from the start of the
        key, so the result is the same whether or not the model is wrapped
        in DistributedDataParallel.
        """
        for prefix in ("module.", "model."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        return key

    def _save_model(self) -> None:
        """Save the underlying model's state dict from the main process."""
        logger.info("Saving model")
        if dist_utils.is_main_process():
            state_dict = {
                self._strip_wrapper_prefix(key): value
                for key, value in self.model.state_dict().items()
            }
            torch.save(state_dict, self.params.save_model_path)
        # Keep the other ranks waiting until the checkpoint is written.
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self) -> None:
        """Train until `max_epochs` is reached or patience runs out.

        Evaluation runs once before training and then whenever
        `update_idx` is a non-zero multiple of `eval_steps`; the model is
        saved each time the eval loss improves.
        """
        logger.info("Start finetuning")
        self._reset_stats()
        self._eval_model()
        # NOTE(review): without an eval data loader `is_best_state` never
        # becomes True, so this loop never saves the model.
        batch_itr = self.train_data_loader.get_dataloader()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in batch_itr:
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            "Early termination, as eval loss did not improve "
                            f"over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.epoch_idx += 1
|
|