# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# author: adiyoss

import json
import logging
from pathlib import Path
import os
import time

import torch
import torch.nn.functional as F

from . import augment, distrib, pretrained
from .enhance import enhance
from .evaluate import evaluate
from .stft_loss import MultiResolutionSTFTLoss
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress

logger = logging.getLogger(__name__)


class Solver(object):
    def __init__(self, data, model, optimizer, args):
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer

        # data augment
        augments = []
        if args.remix:
            augments.append(augment.Remix())
        if args.bandmask:
            augments.append(augment.BandMask(args.bandmask, sample_rate=args.sample_rate))
        if args.shift:
            augments.append(augment.Shift(args.shift, args.shift_same))
        if args.revecho:
            augments.append(augment.RevEcho(args.revecho))
        self.augment = torch.nn.Sequential(*augments)

        # Training config
        self.device = args.device
        self.epochs = args.epochs

        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = args.checkpoint
        if self.checkpoint:
            self.checkpoint_file = Path(args.checkpoint_file)
            self.best_file = Path(args.best_file)
            logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.history_file = args.history_file

        self.best_state = None
        self.restart = args.restart
        self.history = []  # Keep track of loss
        self.samples_dir = args.samples_dir  # Where to save samples
        self.num_prints = args.num_prints  # Number of times to log per epoch
        self.args = args
        self.mrstftloss = MultiResolutionSTFTLoss(factor_sc=args.stft_sc_factor,
                                                  factor_mag=args.stft_mag_factor)
        self._reset()

    def _serialize(self):
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        tmp_path = str(self.checkpoint_file) + ".tmp"
        torch.save(package, tmp_path)
        # renaming is sort of atomic on UNIX (not really true on NFS)
        # but still less chances of leaving a half written checkpoint behind.
        os.rename(tmp_path, self.checkpoint_file)

        # Saving only the latest best model.
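        # Reuse the serialized model package, but swap its weights for the best
        # validation state before writing it out.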
        model = package['model']
        model['state'] = self.best_state
        tmp_path = str(self.best_file) + ".tmp"
        torch.save(model, tmp_path)
        os.rename(tmp_path, self.best_file)

    def _reset(self):
        """Load a checkpoint, a previous run or a pre-trained model if requested."""
        load_from = None
        load_best = False
        keep_history = True
        # Reset
        if self.checkpoint and self.checkpoint_file.exists() and not self.restart:
            load_from = self.checkpoint_file
        elif self.continue_from:
            load_from = self.continue_from
            load_best = self.args.continue_best
            keep_history = False

        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, 'cpu')
            if load_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])
            if 'optimizer' in package and not load_best:
                self.optimizer.load_state_dict(package['optimizer'])
            if keep_history:
                self.history = package['history']
            self.best_state = package['best_state']

        continue_pretrained = self.args.continue_pretrained
        if continue_pretrained:
            logger.info("Fine tuning from pre-trained model %s", continue_pretrained)
            model = getattr(pretrained, self.args.continue_pretrained)()
            self.model.load_state_dict(model.state_dict())

    def train(self):
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                info = " ".join(f"{k.capitalize()}={v:.5f}" for k, v in metrics.items())
                logger.info(f"Epoch {epoch + 1}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(
                bold(f'Train Summary | End of Epoch {epoch + 1} | '
                     f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            if self.cv_loader:
                # Cross validation
                logger.info('-' * 70)
                logger.info('Cross validation...')
                self.model.eval()
                with torch.no_grad():
                    valid_loss = self._run_one_epoch(epoch, cross_valid=True)
                logger.info(
                    bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                         f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
            else:
                valid_loss = 0

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss, 'valid': valid_loss, 'best': best_loss}
            # Save the best model
            if valid_loss == best_loss:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # Evaluate and enhance samples every `eval_every` epochs, and on the last epoch.
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    pesq, stoi = evaluate(self.args, self.model, self.tt_loader)
                metrics.update({'pesq': pesq, 'stoi': stoi})

                # enhance some samples
                logger.info('Enhance and save samples...')
                enhance(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                with open(self.history_file, "w") as hist_file:
                    json.dump(self.history, hist_file, indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize()
                    logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
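        """Run a single pass over the train loader (or the cv loader when
        `cross_valid` is True) and return the average loss for the epoch."""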
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
            noisy, clean = [x.to(self.device) for x in data]
            if not cross_valid:
                # Augment the noise and clean signals separately, then remix them.
                sources = torch.stack([noisy - clean, clean])
                sources = self.augment(sources)
                noise, clean = sources
                noisy = noise + clean
            estimate = self.dmodel(noisy)
            # apply a loss function on the enhanced estimate
            with torch.autograd.set_detect_anomaly(True):
                if self.args.loss == 'l1':
                    loss = F.l1_loss(clean, estimate)
                elif self.args.loss == 'l2':
                    loss = F.mse_loss(clean, estimate)
                elif self.args.loss == 'huber':
                    loss = F.smooth_l1_loss(clean, estimate)
                else:
                    raise ValueError(f"Invalid loss {self.args.loss}")
                # MultiResolution STFT loss
                if self.args.stft_loss:
                    sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
                    loss += sc_loss + mag_loss

                # optimize model in training mode
                if not cross_valid:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))
            # Just in case, clear some memory
            del loss, estimate
        return distrib.average([total_loss / (i + 1)], i + 1)[0]
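
# Usage sketch (assumes the training entry point has already built the loaders,
# the model, the optimizer and an `args` namespace with the attributes read above):
#
#   data = {'tr_loader': tr_loader, 'cv_loader': cv_loader, 'tt_loader': tt_loader}
#   solver = Solver(data, model, optimizer, args)
#   solver.train()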