import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

from models.base import BaseLearner
from models.foster import FOSTER
from models.icarl import iCaRL
from utils.toolkit import count_parameters

EPSILON = 1e-8
batch_size = 32
weight_decay = 2e-4
num_workers = 8


class RMMBase(BaseLearner):
    def __init__(self, args):
        self._args = args
        self._m_rate_list = args.get("m_rate_list", [])
        self._c_rate_list = args.get("c_rate_list", [])

    @property
    def samples_per_class(self):
        return int(self.memory_size // self._total_classes)

    @property
    def memory_size(self):
        img_per_cls = 500 if self._args["dataset"] == "cifar100" else 1300
        if self._m_rate_list[self._cur_task] != 0:
            # Grow the base budget by the slots freed from new-task data,
            # capped below the total number of stored training images.
            self._memory_size = min(
                int(self._total_classes * img_per_cls - 1),
                self._args["memory_size"]
                + int(
                    self._m_rate_list[self._cur_task]
                    * self._args["increment"]
                    * img_per_cls
                ),
            )
        return self._memory_size

    @property
    def new_memory_size(self):
        img_per_cls = 500 if self._args["dataset"] == "cifar100" else 1300
        return int(
            (1 - self._m_rate_list[self._cur_task])
            * self._args["increment"]
            * img_per_cls
        )

    def build_rehearsal_memory(self, data_manager, per_class):
        self._reduce_exemplar(data_manager, per_class)
        self._construct_exemplar(data_manager, per_class)

    def _construct_exemplar(self, data_manager, m):
        img_per_cls = 500 if self._args["dataset"] == "cifar100" else 1300
        # Two candidate per-class budgets; the entropy of each new class
        # decides which one it receives.
        ns = [
            min(img_per_cls, int(m * (1 - self._c_rate_list[self._cur_task]))),
            min(img_per_cls, int(m * (1 + self._c_rate_list[self._cur_task]))),
        ]
        logging.info(
            "Constructing exemplars... ({} or {} per class)".format(ns[0], ns[1])
        )

        # First pass: mean training cross-entropy of every new class.
        all_cls_entropies = []
        ms = []
        for class_idx in range(self._known_classes, self._total_classes):
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            with torch.no_grad():
                cls_entropies = []
                for _, inputs, targets in idx_loader:
                    inputs, targets = inputs.to(self._device), targets.to(self._device)
                    logits = self._network(inputs)["logits"]
                    cross_entropy = (
                        F.cross_entropy(logits, targets, reduction="none")
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    cls_entropies.append(cross_entropy)
                all_cls_entropies.append(np.mean(np.concatenate(cls_entropies)))

        # Classes harder than the median get the smaller budget ns[0];
        # the rest get the larger budget ns[1].
        entropy_median = np.median(all_cls_entropies)
        for the_entropy in all_cls_entropies:
            if the_entropy > entropy_median:
                ms.append(ns[0])
            else:
                ms.append(ns[1])
        logging.info(f"ms: {ms}")
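        # Worked illustration with hypothetical numbers (not computed here):
        # with m = 20 and c_rate = 0.5, ns = [10, 30]; a new class whose mean
        # cross-entropy is above the median of this task's classes is allotted
        # ns[0] = 10 exemplars, and every other class gets ns[1] = 30.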
        # Second pass: iCaRL-style herding selection under the per-class budgets.
        for class_idx in range(self._known_classes, self._total_classes):
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            class_mean = np.mean(vectors, axis=0)

            # Greedily pick the sample that keeps the running exemplar mean
            # closest to the true class mean.
            selected_exemplars = []
            exemplar_vectors = []  # [n, feature_dim]
            for k in range(1, ms[class_idx - self._known_classes] + 1):
                S = np.sum(
                    exemplar_vectors, axis=0
                )  # [feature_dim] sum of the selected exemplars' vectors
                mu_p = (vectors + S) / k  # [n, feature_dim] candidate means
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                selected_exemplars.append(
                    np.array(data[i])
                )  # New object to avoid passing by reference
                exemplar_vectors.append(
                    np.array(vectors[i])
                )  # New object to avoid passing by reference
                vectors = np.delete(
                    vectors, i, axis=0
                )  # Remove it to avoid duplicate selection
                data = np.delete(
                    data, i, axis=0
                )  # Remove it to avoid duplicate selection

            selected_exemplars = np.array(selected_exemplars)
            exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, exemplar_targets))
                if len(self._targets_memory) != 0
                else exemplar_targets
            )

            # Exemplar mean, recomputed on the selected set and re-normalized.
            idx_dataset = data_manager.get_dataset(
                [],
                source="train",
                mode="test",
                appendent=(selected_exemplars, exemplar_targets),
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)
            self._class_means[class_idx, :] = mean


class RMM_iCaRL(RMMBase, iCaRL):  # RMMBase must precede the original method.
    def __init__(self, args):
        RMMBase.__init__(self, args)
        iCaRL.__init__(self, args)

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
            m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
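# Mixin-ordering note (explanatory aside): listing RMMBase first puts it
# before the wrapped learner in Python's C3 linearization, so lookups of
# `samples_per_class`, `memory_size`, and `build_rehearsal_memory` resolve to
# the RMM overrides rather than the inherited versions. Assuming iCaRL
# derives from BaseLearner (as in PyCIL), the resolution order is:
#
#   RMM_iCaRL.__mro__ -> (RMM_iCaRL, RMMBase, iCaRL, BaseLearner, object)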
class RMM_FOSTER(RMMBase, FOSTER):
    def __init__(self, args):
        RMMBase.__init__(self, args)
        FOSTER.__init__(self, args)

    def incremental_train(self, data_manager):
        self.data_manager = data_manager
        self._cur_task += 1
        if self._cur_task > 1:
            # From the second incremental task on, continue from the
            # compressed student network produced by FOSTER.
            self._network = self._snet
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        self._network_module_ptr = self._network
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task > 0:
            # Freeze the old backbone and the old classification head.
            for p in self._network.convnets[0].parameters():
                p.requires_grad = False
            for p in self._network.oldfc.parameters():
                p.requires_grad = False

        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
            m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=self.args["num_workers"],
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
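
if __name__ == "__main__":
    # Minimal sketch of the memory arithmetic implemented by RMMBase above,
    # run on hypothetical numbers (CIFAR-100, base memory 2000, increment 10,
    # m_rate 0.3, 50 classes seen so far); none of these values come from a
    # real config file.
    img_per_cls, base_memory, increment, m_rate = 500, 2000, 10, 0.3
    total_classes = 50
    memory_size = min(
        total_classes * img_per_cls - 1,  # cannot exceed the stored images
        base_memory + int(m_rate * increment * img_per_cls),
    )
    new_memory_size = int((1 - m_rate) * increment * img_per_cls)
    print(memory_size, new_memory_size, memory_size // total_classes)
    # -> 3500 3500 70: the policy drops 1500 new-task samples in exchange for
    #    1500 extra exemplar slots, leaving 70 exemplars per seen class.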