import logging

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from models.base import BaseLearner
from utils.inc_net import IncrementalNetWithBias

# Hyper-parameters shared by both training stages.
epochs = 170
lrate = 0.1
milestones = [60, 100, 140]
lrate_decay = 0.1
batch_size = 128
split_ratio = 0.1
T = 2
weight_decay = 2e-4
num_workers = 8


class BiC(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNetWithBias(
            args, False, bias_correction=True
        )
        self._class_means = None

    def after_task(self):
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task >= 1:
            # From the second task onward, hold out a validation split
            # (drawn from the exemplar memory) for stage-2 bias correction.
            train_dset, val_dset = data_manager.get_dataset_with_split(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
                val_samples_per_class=int(
                    split_ratio * self._memory_size / self._known_classes
                ),
            )
            self.val_loader = DataLoader(
                val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
            )
            logging.info(
                "Stage1 dset: {}, Stage2 dset: {}".format(
                    len(train_dset), len(val_dset)
                )
            )
            self.lamda = self._known_classes / self._total_classes
            logging.info("Lambda: {:.3f}".format(self.lamda))
        else:
            train_dset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
            )
        test_dset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )

        self.train_loader = DataLoader(
            train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self._log_bias_params()
        self._stage1_training(self.train_loader, self.test_loader)
        if self._cur_task >= 1:
            self._stage2_bias_correction(self.val_loader, self.test_loader)

        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
        self._log_bias_params()

    def _run(self, train_loader, test_loader, optimizer, scheduler, stage):
        for epoch in range(1, epochs + 1):
            self._network.train()
            losses = 0.0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                if stage == "training":
                    # Stage 1: cross-entropy on all classes, plus knowledge
                    # distillation from the frozen old network on old classes,
                    # weighted by lamda = known / total classes.
                    clf_loss = F.cross_entropy(logits, targets)
                    if self._old_network is not None:
                        old_logits = self._old_network(inputs)["logits"].detach()
                        hat_pai_k = F.softmax(old_logits / T, dim=1)
                        log_pai_k = F.log_softmax(
                            logits[:, : self._known_classes] / T, dim=1
                        )
                        distill_loss = -torch.mean(
                            torch.sum(hat_pai_k * log_pai_k, dim=1)
                        )
                        loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda)
                    else:
                        loss = clf_loss
                elif stage == "bias_correction":
                    # Stage 2: only the newest bias layer is in the optimizer,
                    # so only its parameters are updated by this loss.
                    loss = F.cross_entropy(torch.softmax(logits, dim=1), targets)
                else:
                    raise NotImplementedError()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

            scheduler.step()
            train_acc = self._compute_accuracy(self._network, train_loader)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format(
                stage,
                self._cur_task,
                epoch,
                epochs,
                losses / len(train_loader),
                train_acc,
                test_acc,
            )
            logging.info(info)

    def _stage1_training(self, train_loader, test_loader):
        """
        if self._cur_task == 0:
            loaded_dict = torch.load('./dict_0.pkl')
            self._network.load_state_dict(loaded_dict['model_state_dict'])
            self._network.to(self._device)
            return
        """
        # Train backbone and classifier; bias layers are frozen (lr=0).
        ignored_params = list(map(id, self._network.bias_layers.parameters()))
        base_params = filter(
            lambda p: id(p) not in ignored_params, self._network.parameters()
        )
        network_params = [
            {"params": base_params, "lr": lrate, "weight_decay": weight_decay},
            {
                "params": self._network.bias_layers.parameters(),
                "lr": 0,
                "weight_decay": 0,
            },
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        self._run(train_loader, test_loader, optimizer, scheduler, stage="training")

    def _stage2_bias_correction(self, val_loader, test_loader):
        if isinstance(self._network, nn.DataParallel):
            self._network = self._network.module
        # Optimize only the bias layer of the newest task on the held-out split.
        network_params = [
            {
                "params": self._network.bias_layers[-1].parameters(),
                "lr": lrate,
                "weight_decay": weight_decay,
            }
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)

        self._run(
            val_loader, test_loader, optimizer, scheduler, stage="bias_correction"
        )

    def _log_bias_params(self):
        logging.info("Parameters of bias layer:")
        params = self._network.get_bias_params()
        for i, param in enumerate(params):
            logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1]))
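

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by this learner): the bias-correction
# idea that the bias layers of IncrementalNetWithBias are assumed to
# implement (the real layers live in utils.inc_net). BiC rescales the logits
# of the newest task with two learnable scalars (alpha, beta) fitted in
# stage 2 on the held-out validation split, while logits of older classes
# pass through unchanged. Class and argument names below are hypothetical.
# ---------------------------------------------------------------------------
class _BiasLayerSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, logits, low, high):
        # Correct only the newest-task columns [low, high); e.g. with 50 known
        # classes and 10 new ones, this would be called as (logits, 50, 60).
        corrected = logits.clone()
        corrected[:, low:high] = self.alpha * logits[:, low:high] + self.beta
        return corrected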