import logging

import numpy as np
import ot
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import SimpleCosineIncrementalNet
from utils.toolkit import tensor2numpy

EPSILON = 1e-8

# Training hyperparameters.
epochs = 100
lrate = 0.1
milestones = [40, 80]
lrate_decay = 0.1
batch_size = 32
memory_size = 2000
T = 2  # distillation temperature
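
# Illustrative only: a minimal, self-contained sketch of the Sinkhorn-based
# weight transport that `COIL.solving_ot` performs below. It is never called by
# the model; the shapes and the regularization value are toy assumptions, not
# the method's actual settings.
def _ot_weight_transport_sketch():
    n_old, n_new, d = 10, 5, 64
    old_means = F.normalize(torch.randn(n_old, d), p=2, dim=1)  # seen-class prototypes
    new_means = F.normalize(torch.randn(n_new, d), p=2, dim=1)  # incoming-class prototypes
    old_weights = torch.randn(n_old, d)                         # old classifier rows

    # Uniform marginals of equal total mass; entropic OT plan of shape (n_old, n_new).
    cost = torch.cdist(old_means, new_means, p=2)
    mu_old = torch.ones(n_old) / n_old
    mu_new = torch.ones(n_new) / n_new
    plan = ot.sinkhorn(mu_old, mu_new, cost, reg=0.1)

    # Each plan column sums to 1 / n_new, so after rescaling every transported
    # row is a convex combination of the normalized old weights.
    new_weights = torch.mm(plan.T, F.normalize(old_weights, p=2, dim=1)) * n_new
    return new_weights
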
class COIL(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = SimpleCosineIncrementalNet(args, False)
        self.data_manager = None
        self.nextperiod_initialization = None
        self.sinkhorn_reg = args["sinkhorn"]
        self.calibration_term = args["calibration_term"]
        self.args = args

    def after_task(self):
        self.nextperiod_initialization = self.solving_ot()
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes

    def solving_ot(self):
        """Transport the learned classifier onto the next task's classes.

        Solves entropic OT between the class-mean prototypes of the seen classes
        and those of the upcoming classes, then applies the plan to the
        L2-normalized classifier weights to initialize the next-period head.
        """
        with torch.no_grad():
            if self._total_classes == self.data_manager.get_total_classnum():
                logging.info("Training over; no further OT solving is needed.")
                return None
            # Assumes every incremental task after the first has the same size.
            each_time_class_num = self.data_manager.get_task_size(1)
            self._extract_class_means(
                self.data_manager, 0, self._total_classes + each_time_class_num
            )
            former_class_means = torch.tensor(
                self._ot_prototype_means[: self._total_classes]
            )
            next_period_class_means = torch.tensor(
                self._ot_prototype_means[
                    self._total_classes : self._total_classes + each_time_class_num
                ]
            )
            Q_cost_matrix = torch.cdist(
                former_class_means, next_period_class_means, p=self.args["norm_term"]
            )

            # Balanced Sinkhorn needs marginals of equal total mass, so each is
            # a uniform probability vector over its own classes.
            _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means)
            _mu2_vec = torch.ones(len(next_period_class_means)) / len(
                next_period_class_means
            )
            # Note: the module-level `T` is the distillation temperature, so the
            # transport plan gets its own name here.
            plan = ot.sinkhorn(_mu1_vec, _mu2_vec, Q_cost_matrix, self.sinkhorn_reg)
            plan = torch.as_tensor(plan).float().to(self._device)

            # Each plan column carries mass 1 / n_new; rescaling by the number
            # of new classes makes every transported row a convex combination of
            # the normalized old classifier weights.
            transformed_hat_W = torch.mm(
                plan.T, F.normalize(self._network.fc.weight, p=2, dim=1)
            ) * len(next_period_class_means)

            # Calibrate the transported weights to match the average norm of the
            # existing classifier rows.
            oldnorm = torch.norm(self._network.fc.weight, p=2, dim=1)
            newnorm = torch.norm(transformed_hat_W, p=2, dim=1)
            gamma = torch.mean(oldnorm) / torch.mean(newnorm)
            self.calibration_term = gamma
            self._ot_new_branch = transformed_hat_W * self.calibration_term
            return transformed_hat_W * self.calibration_term

    def solving_ot_to_old(self):
        """Transport the current task's classifier back onto the old classes."""
        current_class_num = self.data_manager.get_task_size(self._cur_task)
        self._extract_class_means_with_memory(
            self.data_manager, self._known_classes, self._total_classes
        )
        former_class_means = torch.tensor(
            self._ot_prototype_means[: self._known_classes]
        )
        next_period_class_means = torch.tensor(
            self._ot_prototype_means[self._known_classes : self._total_classes]
        )
        # EPSILON guards against zero-cost entries causing numerical errors.
        Q_cost_matrix = (
            torch.cdist(
                next_period_class_means, former_class_means, p=self.args["norm_term"]
            )
            + EPSILON
        )
        _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means)
        _mu2_vec = torch.ones(len(next_period_class_means)) / len(
            next_period_class_means
        )
        plan = ot.sinkhorn(_mu2_vec, _mu1_vec, Q_cost_matrix, self.sinkhorn_reg)
        plan = torch.as_tensor(plan).float().to(self._device)
        # Columns of the plan sum to 1 / n_old, so rescaling by the number of old
        # classes again yields convex combinations of the new-class weights.
        transformed_hat_W = torch.mm(
            plan.T,
            F.normalize(self._network.fc.weight[-current_class_num:, :], p=2, dim=1),
        )
        return transformed_hat_W * len(former_class_means) * self.calibration_term

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes, self.nextperiod_initialization)
        self.data_manager = data_manager
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )
        self.lamda = self._known_classes / self._total_classes

        # Loaders
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )

        self._train(self.train_loader, self.test_loader)

        if self.args["fixed_memory"]:
            exemplar_size = self.args["memory_per_class"]
        else:
            exemplar_size = memory_size // self._total_classes

        self._reduce_exemplar(data_manager, exemplar_size)
        self._construct_exemplar(data_manager, exemplar_size)

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)
        optimizer = optim.SGD(
            self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )
        self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            # The prospective (new-branch) transport loss fades out after the
            # first epoch; the retrospective co-tuning loss ramps up
            # quadratically over training.
            weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0)
            weight_ot_co_tuning = (epoch / epochs) ** 2.0

            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                output = self._network(inputs)
                logits = output["logits"]
                clf_loss = F.cross_entropy(logits, targets)

                if self._old_network is not None:
                    # Standard knowledge distillation on the old classes.
                    old_logits = self._old_network(inputs)["logits"].detach()
                    hat_pai_k = F.softmax(old_logits / T, dim=1)
                    log_pai_k = F.log_softmax(
                        logits[:, : self._known_classes] / T, dim=1
                    )
                    distill_loss = -torch.mean(
                        torch.sum(hat_pai_k * log_pai_k, dim=1)
                    )
                    features = F.normalize(output["features"], p=2, dim=1)

                    if epoch < 1:
                        # First epoch: distill the new branch towards the
                        # OT-transported initialization from the previous task.
                        current_logit_new = F.log_softmax(
                            logits[:, self._known_classes :] / T, dim=1
                        )
                        new_logit_by_wnew_init_by_ot = F.linear(
                            features, F.normalize(self._ot_new_branch, p=2, dim=1)
                        )
                        new_logit_by_wnew_init_by_ot = F.softmax(
                            new_logit_by_wnew_init_by_ot / T, dim=1
                        )
                        new_branch_distill_loss = -torch.mean(
                            torch.sum(
                                current_logit_new * new_logit_by_wnew_init_by_ot,
                                dim=1,
                            )
                        )
                        loss = (
                            distill_loss * self.lamda
                            + clf_loss * (1 - self.lamda)
                            + 0.001 * (weight_ot_init * new_branch_distill_loss)
                        )
                    else:
                        # Later epochs: periodically refresh the reverse
                        # transport of the new weights onto the old classes and
                        # distill the old-class predictions towards it
                        # (co-tuning).
                        if i % 30 == 0:
                            with torch.no_grad():
                                self._ot_old_branch = self.solving_ot_to_old()
                        old_logit_by_wold_init_by_ot = F.linear(
                            features, F.normalize(self._ot_old_branch, p=2, dim=1)
                        )
                        old_logit_by_wold_init_by_ot = F.log_softmax(
                            old_logit_by_wold_init_by_ot / T, dim=1
                        )
                        old_branch_distill_loss = -torch.mean(
                            torch.sum(
                                hat_pai_k * old_logit_by_wold_init_by_ot, dim=1
                            )
                        )
                        loss = (
                            distill_loss * self.lamda
                            + clf_loss * (1 - self.lamda)
                            + self.args["reg_term"]
                            * (weight_ot_co_tuning * old_branch_distill_loss)
                        )
                else:
                    # First task: plain cross-entropy.
                    loss = clf_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                epochs,
                losses / len(train_loader),
                train_acc,
                test_acc,
            )
            prog_bar.set_description(info)
        logging.info(info)
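
    # The two helpers below build `self._ot_prototype_means`, the matrix of
    # L2-normalized class-mean features that anchors the OT cost matrices in
    # `solving_ot` and `solving_ot_to_old`. The plain variant reads every class
    # from the full training split; the `_with_memory` variant falls back to the
    # exemplar memory for old classes, whose full training data is no longer
    # available, and uses the training split only for the current task.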
    def _extract_class_means(self, data_manager, low, high):
        self._ot_prototype_means = np.zeros(
            (data_manager.get_total_classnum(), self._network.feature_dim)
        )
        with torch.no_grad():
            for class_idx in range(low, high):
                data, targets, idx_dataset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                idx_loader = DataLoader(
                    idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
                )
                vectors, _ = self._extract_vectors(idx_loader)
                # Normalize each feature vector, then the class mean itself.
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                class_mean = np.mean(vectors, axis=0)
                class_mean = class_mean / np.linalg.norm(class_mean)
                self._ot_prototype_means[class_idx, :] = class_mean
        self._network.train()

    def _extract_class_means_with_memory(self, data_manager, low, high):
        self._ot_prototype_means = np.zeros(
            (data_manager.get_total_classnum(), self._network.feature_dim)
        )
        memoryx, memoryy = self._data_memory, self._targets_memory
        with torch.no_grad():
            # Old classes: estimate prototypes from the exemplar memory.
            for class_idx in range(0, low):
                idxes = np.where(memoryy == class_idx)[0]
                data, targets = memoryx[idxes], memoryy[idxes]
                _, _, idx_dataset = data_manager.get_dataset(
                    [],
                    source="train",
                    appendent=(data, targets),
                    mode="test",
                    ret_data=True,
                )
                idx_loader = DataLoader(
                    idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
                )
                vectors, _ = self._extract_vectors(idx_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                class_mean = np.mean(vectors, axis=0)
                class_mean = class_mean / np.linalg.norm(class_mean)
                self._ot_prototype_means[class_idx, :] = class_mean
            # Current classes: estimate prototypes from the full training data.
            for class_idx in range(low, high):
                data, targets, idx_dataset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                idx_loader = DataLoader(
                    idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
                )
                vectors, _ = self._extract_vectors(idx_loader)
                vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
                class_mean = np.mean(vectors, axis=0)
                class_mean = class_mean / np.linalg.norm(class_mean)
                self._ot_prototype_means[class_idx, :] = class_mean
        self._network.train()
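
# A typical driving loop, assuming the PyCIL-style trainer conventions this
# class relies on: a `DataManager` exposing `nb_tasks`, `get_task_size`,
# `get_dataset`, and `get_total_classnum`, and an `args` dict holding the keys
# read above ("sinkhorn", "calibration_term", "norm_term", "reg_term",
# "fixed_memory", and optionally "memory_per_class"). Constructor arguments are
# elided because they depend on the surrounding repository's configuration:
#
#     data_manager = DataManager(...)  # dataset, class order, increments
#     learner = COIL(args)
#     for _ in range(data_manager.nb_tasks):
#         learner.incremental_train(data_manager)
#         learner.after_task()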