import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections.abc import MutableMapping

# from speechbrain.processing.speech_augmentation import SpeedPerturb
def flatten_dict(d, parent_key="", sep="_"):
    """Flattens a dictionary into a single-level dictionary while preserving
    parent keys. Taken from
    `SO <https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys>`_

    Args:
        d (MutableMapping): Dictionary to be flattened.
        parent_key (str): String to use as a prefix to all subsequent keys.
        sep (str): String to use as a separator between two key levels.

    Returns:
        dict: Single-level dictionary, flattened.
    """
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
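# Illustrative usage of flatten_dict (a sketch, not part of the original module):
# nested config keys are joined by `sep`, so a nested dict collapses to a flat
# one suitable for hyperparameter logging, e.g.
#
#     flatten_dict({"optim": {"lr": 1e-3}, "data": {"sample_rate": 8000}})
#     # -> {"optim_lr": 0.001, "data_sample_rate": 8000}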
class AudioLightningModuleMultiDecoder(pl.LightningModule):
    def __init__(
        self,
        audio_model=None,
        video_model=None,
        optimizer=None,
        loss_func=None,
        train_loader=None,
        val_loader=None,
        test_loader=None,
        scheduler=None,
        config=None,
    ):
        super().__init__()
        self.audio_model = audio_model
        self.video_model = video_model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scheduler = scheduler
        self.config = {} if config is None else config

        # Speed augmentation (disabled; requires speechbrain's SpeedPerturb)
        # self.speedperturb = SpeedPerturb(
        #     self.config["datamodule"]["data_config"]["sample_rate"],
        #     speeds=[95, 100, 105],
        #     perturb_prob=1.0,
        # )

        # Save Lightning's AttributeDict under self.hparams
        self.default_monitor = "val_loss/dataloader_idx_0"
        self.save_hyperparameters(self.config_to_hparams(self.config))
        # self.print(self.audio_model)
        self.validation_step_outputs = []
        self.test_step_outputs = []
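    # Construction sketch (illustrative only; `MyAudioModel`, the loss callables
    # and the dataloaders are placeholder names, not part of this file):
    #
    #     model = MyAudioModel(...)
    #     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #     system = AudioLightningModuleMultiDecoder(
    #         audio_model=model,
    #         optimizer=optimizer,
    #         loss_func={"train": train_loss, "val": val_loss},
    #         train_loader=train_loader,
    #         val_loader=val_loader,
    #         test_loader=test_loader,
    #         scheduler=ReduceLROnPlateau(optimizer, factor=0.5, patience=5),
    #         config=config,
    #     )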
    def forward(self, wav, mouth=None):
        """Applies the forward pass of the model.

        Returns:
            :class:`torch.Tensor`
        """
        if self.training:
            # During training, return the stacked estimates of every decoder
            # block together with the number of blocks.
            return self.audio_model(wav)
        else:
            # At inference, keep only the estimates of the last decoder block.
            ests, num_blocks = self.audio_model(wav)
            assert ests.shape[0] % num_blocks == 0
            batch_size = ests.shape[0] // num_blocks
            return ests[-batch_size:]
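    # Assumed output layout (inferred from the slicing above): the model stacks
    # the estimates of every decoder block along the batch dimension, giving a
    # tensor of shape (num_blocks * batch, ...); the last `batch` rows belong to
    # the final block, which is what eval-mode forward returns.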
    def training_step(self, batch, batch_nb):
        mixtures, targets, _ = batch
        new_targets = []
        min_len = -1
        # Optional speed augmentation (disabled together with SpeedPerturb above)
        # if self.config["training"]["SpeedAug"] == True:
        #     with torch.no_grad():
        #         for i in range(targets.shape[1]):
        #             new_target = self.speedperturb(targets[:, i, :])
        #             new_targets.append(new_target)
        #             if i == 0:
        #                 min_len = new_target.shape[-1]
        #             else:
        #                 if new_target.shape[-1] < min_len:
        #                     min_len = new_target.shape[-1]
        #         targets = torch.zeros(
        #             targets.shape[0],
        #             targets.shape[1],
        #             min_len,
        #             device=targets.device,
        #             dtype=torch.float,
        #         )
        #         for i, new_target in enumerate(new_targets):
        #             targets[:, i, :] = new_targets[i][:, 0:min_len]
        #         mixtures = targets.sum(1)
        # print(mixtures.shape)

        est_sources, num_blocks = self(mixtures)
        assert est_sources.shape[0] % num_blocks == 0
        batch_size = est_sources.shape[0] // num_blocks
        # Split the stacked estimates back into one chunk per decoder block so
        # the training loss can supervise every intermediate decoder.
        ests_sources_each_block = []
        for i in range(num_blocks):
            ests_sources_each_block.append(est_sources[i * batch_size : (i + 1) * batch_size])
        loss = self.loss_func["train"](ests_sources_each_block, targets)
        self.log(
            "train_loss",
            loss,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            logger=True,
        )
        return {"loss": loss}
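    # Note: the training criterion is expected to accept a list with one
    # estimate tensor per decoder block (deep supervision), whereas the
    # validation criterion below only sees the final block returned by
    # eval-mode `forward`.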
    def validation_step(self, batch, batch_nb, dataloader_idx):
        # Dataloader 0: validation loss
        if dataloader_idx == 0:
            mixtures, targets, _ = batch
            # print(mixtures.shape)
            est_sources = self(mixtures)
            loss = self.loss_func["val"](est_sources, targets)
            self.log(
                "val_loss",
                loss,
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
                logger=True,
            )
            self.validation_step_outputs.append(loss)
            return {"val_loss": loss}

        # Dataloader 1: test loss, only computed every 10th epoch
        if (self.trainer.current_epoch) % 10 == 0 and dataloader_idx == 1:
            mixtures, targets, _ = batch
            # print(mixtures.shape)
            est_sources = self(mixtures)
            tloss = self.loss_func["val"](est_sources, targets)
            self.log(
                "test_loss",
                tloss,
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
                logger=True,
            )
            self.test_step_outputs.append(tloss)
            return {"test_loss": tloss}
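    # The two dataloader indices come from `val_dataloader` below: index 0 is
    # the validation set (monitored as "val_loss/dataloader_idx_0", matching
    # `self.default_monitor`), index 1 is the test set scored every 10 epochs.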
    def on_validation_epoch_end(self):
        # Validation: average the per-step losses, then average across devices
        avg_loss = torch.stack(self.validation_step_outputs).mean()
        val_loss = torch.mean(self.all_gather(avg_loss))
        self.log(
            "lr",
            self.optimizer.param_groups[0]["lr"],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.logger.experiment.log(
            {"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch}
        )
        self.logger.experiment.log(
            {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
        )

        # Test: outputs are only populated every 10th epoch (see validation_step)
        if (self.trainer.current_epoch) % 10 == 0:
            avg_loss = torch.stack(self.test_step_outputs).mean()
            test_loss = torch.mean(self.all_gather(avg_loss))
            self.logger.experiment.log(
                {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
            )

        self.validation_step_outputs.clear()  # free memory
        self.test_step_outputs.clear()  # free memory
    def configure_optimizers(self):
        """Initialize optimizers, batch-wise and epoch-wise schedulers."""
        if self.scheduler is None:
            return self.optimizer

        if not isinstance(self.scheduler, (list, tuple)):
            self.scheduler = [self.scheduler]  # support multiple schedulers

        epoch_schedulers = []
        for sched in self.scheduler:
            if not isinstance(sched, dict):
                if isinstance(sched, ReduceLROnPlateau):
                    sched = {"scheduler": sched, "monitor": self.default_monitor}
                epoch_schedulers.append(sched)
            else:
                sched.setdefault("monitor", self.default_monitor)
                sched.setdefault("frequency", 1)
                # Backward compat
                if sched["interval"] == "batch":
                    sched["interval"] = "step"
                assert sched["interval"] in [
                    "epoch",
                    "step",
                ], "Scheduler interval should be either step or epoch"
                epoch_schedulers.append(sched)
        return [self.optimizer], epoch_schedulers

    # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    #     if metric is None:
    #         scheduler.step()
    #     else:
    #         scheduler.step(metric)
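    # Scheduler dict sketch (illustrative values; note the dict form must carry
    # an "interval" key, as configure_optimizers reads it directly):
    #
    #     scheduler = {
    #         "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=2),
    #         "interval": "epoch",  # or "batch"/"step" for per-iteration stepping
    #     }
    #     system = AudioLightningModuleMultiDecoder(..., scheduler=scheduler)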
    def train_dataloader(self):
        """Training dataloader."""
        return self.train_loader

    def val_dataloader(self):
        """Validation and test dataloaders, used as two validation dataloaders."""
        return [self.val_loader, self.test_loader]

    def on_save_checkpoint(self, checkpoint):
        """Overwrite if you want to save more things in the checkpoint."""
        checkpoint["training_config"] = self.config
        return checkpoint
    @staticmethod
    def config_to_hparams(dic):
        """Sanitizes the config dict so it is handled correctly by torch's
        SummaryWriter. It flattens the config dict, converts ``None`` to
        ``"None"``, and casts any list or tuple to a ``torch.Tensor``.

        Args:
            dic (dict): Dictionary to be transformed.

        Returns:
            dict: Transformed dictionary.
        """
        dic = flatten_dict(dic)
        for k, v in dic.items():
            if v is None:
                dic[k] = str(v)
            elif isinstance(v, (list, tuple)):
                dic[k] = torch.tensor(v)
        return dic
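# config_to_hparams sketch (illustrative): None becomes the string "None" and
# lists/tuples become tensors, so every value can be logged as a hyperparameter:
#
#     AudioLightningModuleMultiDecoder.config_to_hparams(
#         {"training": {"speeds": [95, 100, 105], "resume": None}}
#     )
#     # -> {"training_speeds": tensor([ 95, 100, 105]), "training_resume": "None"}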