import math
import logging

import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from models.base import BaseLearner
from utils.inc_net import CosineIncrementalNet
from utils.toolkit import tensor2numpy

epochs = 100
lrate = 0.1
ft_epochs = 20
ft_lrate = 0.005
batch_size = 32
lambda_c_base = 5
lambda_f_base = 1
nb_proxy = 10
weight_decay = 5e-4
num_workers = 4

"""
Distillation losses: POD-flat (lambda_f=1) + POD-spatial (lambda_c=5)
NME results are shown below. The reproduced results fall short of the
reported ones; some detail of the original implementation may still be
missing here.
+--------------------+--------------------+--------------------+--------------------+
|     Classifier     |       Steps        |    Reported (%)    |   Reproduced (%)   |
+--------------------+--------------------+--------------------+--------------------+
|    Cosine (k=1)    |         50         |       56.69        |       55.49        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-CE (k=10)    |         50         |       59.86        |       55.69        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-NCA (k=10)   |         50         |       61.40        |       56.50        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-CE (k=10)    |         25         |       -----        |       59.16        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-NCA (k=10)   |         25         |       62.71        |       59.79        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-CE (k=10)    |         10         |       -----        |       62.59        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-NCA (k=10)   |         10         |       64.03        |       62.81        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-CE (k=10)    |         5          |       -----        |       64.16        |
+--------------------+--------------------+--------------------+--------------------+
|   LSC-NCA (k=10)   |         5          |       64.48        |       64.37        |
+--------------------+--------------------+--------------------+--------------------+
"""


class PODNet(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = CosineIncrementalNet(
            args, pretrained=False, nb_proxy=nb_proxy
        )
        self._class_means = None

    def after_task(self):
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self.task_size = self._total_classes - self._known_classes
        self._network.update_fc(self._total_classes, self._cur_task)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        train_dset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        test_dset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.train_loader = DataLoader(
            train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self._train(data_manager, self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)

    def _train(self, data_manager, train_loader, test_loader):
        if self._cur_task == 0:
            self.factor = 0
        else:
            self.factor = math.sqrt(
                self._total_classes / (self._total_classes - self._known_classes)
            )
        logging.info("Adaptive factor: {}".format(self.factor))
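        # factor = sqrt(n_total / n_new): the distillation terms are scaled up
        # as classes accumulate, following the adaptive weighting used in the
        # PODNet paper (Douillard et al., ECCV 2020).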
{}".format(self.factor)) self._network.to(self._device) if self._old_network is not None: self._old_network.to(self._device) if self._cur_task == 0: network_params = self._network.parameters() else: ignored_params = list(map(id, self._network.fc.fc1.parameters())) base_params = filter( lambda p: id(p) not in ignored_params, self._network.parameters() ) network_params = [ {"params": base_params, "lr": lrate, "weight_decay": weight_decay}, { "params": self._network.fc.fc1.parameters(), "lr": 0, "weight_decay": 0, }, ] optimizer = optim.SGD( network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay ) scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer=optimizer, T_max=epochs ) self._run(train_loader, test_loader, optimizer, scheduler, epochs) if self._cur_task == 0: return logging.info( "Finetune the network (classifier part) with the undersampled dataset!" ) if self._fixed_memory: finetune_samples_per_class = self._memory_per_class self._construct_exemplar_unified(data_manager, finetune_samples_per_class) else: finetune_samples_per_class = self._memory_size // self._known_classes self._reduce_exemplar(data_manager, finetune_samples_per_class) self._construct_exemplar(data_manager, finetune_samples_per_class) finetune_train_dataset = data_manager.get_dataset( [], source="train", mode="train", appendent=self._get_memory() ) finetune_train_loader = DataLoader( finetune_train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, ) logging.info( "The size of finetune dataset: {}".format(len(finetune_train_dataset)) ) ignored_params = list(map(id, self._network.fc.fc1.parameters())) base_params = filter( lambda p: id(p) not in ignored_params, self._network.parameters() ) network_params = [ {"params": base_params, "lr": ft_lrate, "weight_decay": weight_decay}, {"params": self._network.fc.fc1.parameters(), "lr": 0, "weight_decay": 0}, ] optimizer = optim.SGD( network_params, lr=ft_lrate, momentum=0.9, weight_decay=weight_decay ) scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer=optimizer, T_max=ft_epochs ) self._run(finetune_train_loader, test_loader, optimizer, scheduler, ft_epochs) if self._fixed_memory: self._data_memory = self._data_memory[ : -self._memory_per_class * self.task_size ] self._targets_memory = self._targets_memory[ : -self._memory_per_class * self.task_size ] assert ( len( np.setdiff1d( self._targets_memory, np.arange(0, self._known_classes) ) ) == 0 ), "Exemplar error!" 
    def _run(self, train_loader, test_loader, optimizer, scheduler, epk):
        for epoch in range(1, epk + 1):
            self._network.train()
            lsc_losses = 0.0
            spatial_losses = 0.0
            flat_losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits = outputs["logits"]
                features = outputs["features"]
                fmaps = outputs["fmaps"]
                lsc_loss = nca(logits, targets)

                spatial_loss = 0.0
                flat_loss = 0.0
                if self._old_network is not None:
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                        old_features = old_outputs["features"]
                        old_fmaps = old_outputs["fmaps"]
                    # POD-flat: cosine distillation on the pooled embeddings.
                    flat_loss = (
                        F.cosine_embedding_loss(
                            features,
                            old_features.detach(),
                            torch.ones(inputs.shape[0]).to(self._device),
                        )
                        * self.factor
                        * lambda_f_base
                    )
                    # POD-spatial: distillation on intermediate feature maps.
                    spatial_loss = (
                        pod_spatial_loss(old_fmaps, fmaps) * self.factor * lambda_c_base
                    )

                loss = lsc_loss + flat_loss + spatial_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lsc_losses += lsc_loss.item()
                spatial_losses += (
                    spatial_loss.item() if self._cur_task != 0 else spatial_loss
                )
                flat_losses += flat_loss.item() if self._cur_task != 0 else flat_loss

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if scheduler is not None:
                scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = "Task {}, Epoch {}/{} (LR {:.5f}) => ".format(
                self._cur_task, epoch, epk, optimizer.param_groups[0]["lr"]
            )
            info2 = (
                "LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, "
                "Train_acc {:.2f}, Test_acc {:.2f}".format(
                    lsc_losses / (i + 1),
                    spatial_losses / (i + 1),
                    flat_losses / (i + 1),
                    train_acc,
                    test_acc,
                )
            )
            logging.info(info1 + info2)


def pod_spatial_loss(old_fmaps, fmaps, normalize=True):
    """Pooled Output Distillation (POD-spatial).

    old_fmaps, fmaps: lists of [bs, c, w, h] tensors from matching stages
    of the old and the current network.
    """
    loss = torch.tensor(0.0).to(fmaps[0].device)
    for a, b in zip(old_fmaps, fmaps):
        assert a.shape == b.shape, "Shape error"

        a = torch.pow(a, 2)
        b = torch.pow(b, 2)

        a_h = a.sum(dim=3).view(a.shape[0], -1)  # [bs, c*w]
        b_h = b.sum(dim=3).view(b.shape[0], -1)  # [bs, c*w]
        a_w = a.sum(dim=2).view(a.shape[0], -1)  # [bs, c*h]
        b_w = b.sum(dim=2).view(b.shape[0], -1)  # [bs, c*h]

        a = torch.cat([a_h, a_w], dim=-1)
        b = torch.cat([b_h, b_w], dim=-1)

        if normalize:
            a = F.normalize(a, dim=1, p=2)
            b = F.normalize(b, dim=1, p=2)

        # Per-sample L2 distance between the pooled descriptors.
        layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1))
        loss += layer_loss

    return loss / len(fmaps)


def nca(
    similarities,
    targets,
    class_weights=None,
    focal_gamma=None,
    scale=1.0,
    margin=0.6,
    exclude_pos_denominator=True,
    hinge_proxynca=False,
    memory_flags=None,
):
    # Subtract the margin from the target-class similarity only.
    margins = torch.zeros_like(similarities)
    margins[torch.arange(margins.shape[0]), targets] = margin
    similarities = scale * (similarities - margins)

    if exclude_pos_denominator:  # NCA-style loss
        # Numerical stability: shift logits so the row-wise max is 0.
        similarities = similarities - similarities.max(1)[0].view(-1, 1)

        # Zero out the positive similarity before the denominator's log-sum-exp.
        disable_pos = torch.zeros_like(similarities)
        disable_pos[torch.arange(len(similarities)), targets] = similarities[
            torch.arange(len(similarities)), targets
        ]

        numerator = similarities[torch.arange(similarities.shape[0]), targets]
        denominator = similarities - disable_pos

        losses = numerator - torch.log(torch.exp(denominator).sum(-1))
        if class_weights is not None:
            losses = class_weights[targets] * losses

        losses = -losses
        if hinge_proxynca:
            losses = torch.clamp(losses, min=0.0)

        loss = torch.mean(losses)
        return loss
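    # Fallback: standard (optionally class-weighted) cross-entropy on the
    # margin-adjusted similarities.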
reduction="mean" )