"""A generic training wrapper.""" from copy import deepcopy import logging from typing import Callable, List, Optional import torch from torch.utils.data import DataLoader LOGGER = logging.getLogger(__name__) class Trainer: def __init__( self, epochs: int = 20, batch_size: int = 32, device: str = "cpu", optimizer_fn: Callable = torch.optim.Adam, optimizer_kwargs: dict = {"lr": 1e-3}, use_scheduler: bool = False, ) -> None: self.epochs = epochs self.batch_size = batch_size self.device = device self.optimizer_fn = optimizer_fn self.optimizer_kwargs = optimizer_kwargs self.epoch_test_losses: List[float] = [] self.use_scheduler = use_scheduler def forward_and_loss(model, criterion, batch_x, batch_y, **kwargs): batch_out = model(batch_x) batch_loss = criterion(batch_out, batch_y) return batch_out, batch_loss class GDTrainer(Trainer): def train( self, dataset: torch.utils.data.Dataset, model: torch.nn.Module, test_len: Optional[float] = None, test_dataset: Optional[torch.utils.data.Dataset] = None, ): if test_dataset is not None: train = dataset test = test_dataset else: test_len = int(len(dataset) * test_len) train_len = len(dataset) - test_len lengths = [train_len, test_len] train, test = torch.utils.data.random_split(dataset, lengths) train_loader = DataLoader( train, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=6, ) test_loader = DataLoader( test, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=6, ) criterion = torch.nn.BCEWithLogitsLoss() optim = self.optimizer_fn(model.parameters(), **self.optimizer_kwargs) best_model = None best_acc = 0 LOGGER.info(f"Starting training for {self.epochs} epochs!") forward_and_loss_fn = forward_and_loss if self.use_scheduler: batches_per_epoch = len(train_loader) * 2 # every 2nd epoch scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer=optim, T_0=batches_per_epoch, T_mult=1, eta_min=5e-6, # verbose=True, ) use_cuda = self.device != "cpu" for epoch in range(self.epochs): LOGGER.info(f"Epoch num: {epoch}") running_loss = 0 num_correct = 0.0 num_total = 0.0 model.train() for i, (batch_x, _, batch_y) in enumerate(train_loader): batch_size = batch_x.size(0) num_total += batch_size batch_x = batch_x.to(self.device) batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device) batch_out, batch_loss = forward_and_loss_fn( model, criterion, batch_x, batch_y, use_cuda=use_cuda ) batch_pred = (torch.sigmoid(batch_out) + 0.5).int() num_correct += (batch_pred == batch_y.int()).sum(dim=0).item() running_loss += batch_loss.item() * batch_size if i % 100 == 0: LOGGER.info( f"[{epoch:04d}][{i:05d}]: {running_loss / num_total} {num_correct/num_total*100}" ) optim.zero_grad() batch_loss.backward() optim.step() if self.use_scheduler: scheduler.step() running_loss /= num_total train_accuracy = (num_correct / num_total) * 100 LOGGER.info( f"Epoch [{epoch+1}/{self.epochs}]: train/loss: {running_loss}, train/accuracy: {train_accuracy}" ) test_running_loss = 0.0 num_correct = 0.0 num_total = 0.0 model.eval() eer_val = 0 for batch_x, _, batch_y in test_loader: batch_size = batch_x.size(0) num_total += batch_size batch_x = batch_x.to(self.device) with torch.no_grad(): batch_pred = model(batch_x) batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device) batch_loss = criterion(batch_pred, batch_y) test_running_loss += batch_loss.item() * batch_size batch_pred = torch.sigmoid(batch_pred) batch_pred_label = (batch_pred + 0.5).int() num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item() if 
num_total == 0: num_total = 1 test_running_loss /= num_total test_acc = 100 * (num_correct / num_total) LOGGER.info( f"Epoch [{epoch+1}/{self.epochs}]: test/loss: {test_running_loss}, test/accuracy: {test_acc}, test/eer: {eer_val}" ) if best_model is None or test_acc > best_acc: best_acc = test_acc best_model = deepcopy(model.state_dict()) LOGGER.info( f"[{epoch:04d}]: {running_loss} - train acc: {train_accuracy} - test_acc: {test_acc}" ) model.load_state_dict(best_model) return model
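

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# example of driving GDTrainer, assuming a dataset that yields
# (input, metadata, label) triples -- matching the (batch_x, _, batch_y)
# unpacking above -- and a model that outputs a single logit per sample.
# `_ToyDataset` is a hypothetical stand-in used only for illustration.
# ---------------------------------------------------------------------------
class _ToyDataset(torch.utils.data.Dataset):
    """Random separable data; the label is 1 when the feature sum is positive."""

    def __init__(self, n: int = 512, dim: int = 16) -> None:
        self.x = torch.randn(n, dim)
        self.y = (self.x.sum(dim=1) > 0).long()

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx):
        # (input, metadata placeholder, label) to match the trainer's unpacking.
        return self.x[idx], 0, self.y[idx]


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model = torch.nn.Linear(16, 1)  # single-logit binary classifier
    trainer = GDTrainer(epochs=2, batch_size=32, device="cpu")
    trained_model = trainer.train(_ToyDataset(), model, test_len=0.2)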