import copy
import logging

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import AdaptiveNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

num_workers = 8
EPSILON = 1e-8
batch_size = 32


class MEMO(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._old_base = None
        self._network = AdaptiveNet(args, True)
        logging.info(
            f'>>> train generalized blocks:{self.args["train_base"]} '
            f'train_adaptive:{self.args["train_adaptive"]}'
        )

    def after_task(self):
        self._known_classes = self._total_classes
        if self._cur_task == 0:
            if self.args['train_base']:
                logging.info("Train Generalized Blocks...")
                self._network.TaskAgnosticExtractor.train()
                for param in self._network.TaskAgnosticExtractor.parameters():
                    param.requires_grad = True
            else:
                logging.info("Fix Generalized Blocks...")
                self._network.TaskAgnosticExtractor.eval()
                for param in self._network.TaskAgnosticExtractor.parameters():
                    param.requires_grad = False
        logging.info('Exemplar size: {}'.format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))

        if self._cur_task > 0:
            # Past adaptive extractors stay trainable only when train_adaptive
            # is set; otherwise they are frozen for the new task.
            for i in range(self._cur_task):
                for p in self._network.AdaptiveExtractors[i].parameters():
                    p.requires_grad = self.args['train_adaptive']

        logging.info('All params: {}'.format(count_parameters(self._network)))
        logging.info('Trainable params: {}'.format(count_parameters(self._network, True)))

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source='train',
            mode='train',
            appendent=self._get_memory()
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes),
            source='test',
            mode='test'
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
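    # A minimal driver sketch (assumption: a PyCIL-style `DataManager`; the
    # names below are illustrative and not defined in this module):
    #
    #   data_manager = DataManager(args["dataset"], shuffle=True, seed=args["seed"],
    #                              init_cls=args["init_cls"], increment=args["increment"])
    #   learner = MEMO(args)
    #   for _ in range(data_manager.nb_tasks):
    #       learner.incremental_train(data_manager)
    #       learner.after_task()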
    def set_network(self):
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
        self._network.train()  # switch every submodule from eval back to train
        if self.args['train_base']:
            self._network.TaskAgnosticExtractor.train()
        else:
            self._network.TaskAgnosticExtractor.eval()

        # Set the adaptive extractors' status: the current one always trains,
        # past ones train only when train_adaptive is set.
        self._network.AdaptiveExtractors[-1].train()
        if self._cur_task >= 1:
            for i in range(self._cur_task):
                if self.args['train_adaptive']:
                    self._network.AdaptiveExtractors[i].train()
                else:
                    self._network.AdaptiveExtractors[i].eval()
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"]
            )
            if self.args['scheduler'] == 'steplr':
                scheduler = optim.lr_scheduler.MultiStepLR(
                    optimizer=optimizer,
                    milestones=self.args['init_milestones'],
                    gamma=self.args['init_lr_decay']
                )
            elif self.args['scheduler'] == 'cosine':
                scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer=optimizer,
                    T_max=self.args['init_epoch']
                )
            else:
                raise NotImplementedError

            if not self.args['skip']:
                self._init_train(train_loader, test_loader, optimizer, scheduler)
            else:
                if isinstance(self._network, nn.DataParallel):
                    self._network = self._network.module
                load_acc = self._network.load_checkpoint(self.args)
                self._network.to(self._device)
                if len(self._multiple_gpus) > 1:
                    self._network = nn.DataParallel(self._network, self._multiple_gpus)
                cur_test_acc = self._compute_accuracy(self._network, self.test_loader)
                logging.info(f"Loaded_Test_Acc:{load_acc} Cur_Test_Acc:{cur_test_acc}")
        else:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.args['lrate'],
                momentum=0.9,
                weight_decay=self.args['weight_decay']
            )
            if self.args['scheduler'] == 'steplr':
                scheduler = optim.lr_scheduler.MultiStepLR(
                    optimizer=optimizer,
                    milestones=self.args['milestones'],
                    gamma=self.args['lrate_decay']
                )
            elif self.args['scheduler'] == 'cosine':
                assert self.args['t_max'] is not None
                scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer=optimizer,
                    T_max=self.args['t_max']
                )
            else:
                raise NotImplementedError
            self._update_representation(train_loader, test_loader, optimizer, scheduler)
            # Re-balance the new-class logits against the old ones after training.
            if len(self._multiple_gpus) > 1:
                self._network.module.weight_align(self._total_classes - self._known_classes)
            else:
                self._network.weight_align(self._total_classes - self._known_classes)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.args["init_epoch"]))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)['logits']

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self.args['init_epoch'],
                    losses / len(train_loader), train_acc, test_acc)
            else:
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self.args['init_epoch'],
                    losses / len(train_loader), train_acc)
            # prog_bar.set_description(info)
            logging.info(info)
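    # Worked example of the auxiliary-target remapping used below: with
    # _known_classes = 10 and 10 new classes (labels 10..19), a target t maps
    # to max(t - 10 + 1, 0), so every old class collapses into a single
    # "other" label 0 and the new classes become 1..10:
    #
    #   targets     = [3, 10, 15, 19]
    #   aux_targets = [0,  1,  6, 10]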
    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.args["epochs"]))
        for epoch in prog_bar:
            self.set_network()
            losses = 0.
            losses_clf = 0.
            losses_aux = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits, aux_logits = outputs["logits"], outputs["aux_logits"]
                loss_clf = F.cross_entropy(logits, targets)

                # Remap targets for the auxiliary head: old classes collapse
                # to label 0, new classes become 1..(total - known).
                aux_targets = torch.clamp(targets - self._known_classes + 1, min=0)
                loss_aux = F.cross_entropy(aux_logits, aux_targets)
                loss = loss_clf + self.args['alpha_aux'] * loss_aux

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_aux += loss_aux.item()
                losses_clf += loss_clf.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self.args["epochs"],
                    losses / len(train_loader), losses_clf / len(train_loader),
                    losses_aux / len(train_loader), train_acc, test_acc)
            else:
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self.args["epochs"],
                    losses / len(train_loader), losses_clf / len(train_loader),
                    losses_aux / len(train_loader), train_acc)
            prog_bar.set_description(info)
            logging.info(info)

    def save_checkpoint(self, test_acc):
        assert self.args['model_name'] == 'finetune'
        checkpoint_name = f"checkpoints/finetune_{self.args['csv_name']}"
        _checkpoint_cpu = copy.deepcopy(self._network)
        if isinstance(_checkpoint_cpu, nn.DataParallel):
            _checkpoint_cpu = _checkpoint_cpu.module
        _checkpoint_cpu.cpu()
        save_dict = {
            "tasks": self._cur_task,
            "convnet": _checkpoint_cpu.convnet.state_dict(),
            "fc": _checkpoint_cpu.fc.state_dict(),
            "test_acc": test_acc
        }
        torch.save(save_dict, "{}_{}.pkl".format(checkpoint_name, self._cur_task))
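    # In PyCIL-style base learners, `build_rehearsal_memory` (called from
    # `incremental_train` above) typically hands `_construct_exemplar` below
    # m = samples_per_class; under a fixed total budget that is roughly
    # memory_size // total_classes, e.g. a 2000-sample budget over 100 seen
    # classes leaves 20 exemplars per class.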
    def _construct_exemplar(self, data_manager, m):
        logging.info("Constructing exemplars...({} per class)".format(m))
        for class_idx in range(self._known_classes, self._total_classes):
            data, targets, idx_dataset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            class_mean = np.mean(vectors, axis=0)

            # Greedy herding selection: at each step pick the sample whose
            # inclusion keeps the running exemplar mean closest to the class mean.
            selected_exemplars = []
            exemplar_vectors = []  # [n, feature_dim]
            for k in range(1, m + 1):
                S = np.sum(exemplar_vectors, axis=0)  # [feature_dim] sum of selected vectors
                mu_p = (vectors + S) / k  # [n, feature_dim] candidate means if each remaining sample were added
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                selected_exemplars.append(np.array(data[i]))  # new object to avoid passing by reference
                exemplar_vectors.append(np.array(vectors[i]))  # new object to avoid passing by reference
                vectors = np.delete(vectors, i, axis=0)  # remove it to avoid duplicate selection
                data = np.delete(data, i, axis=0)  # remove it to avoid duplicate selection
                if len(vectors) == 0:
                    break

            selected_exemplars = np.array(selected_exemplars)
            exemplar_targets = np.full(selected_exemplars.shape[0], class_idx)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, exemplar_targets))
                if len(self._targets_memory) != 0
                else exemplar_targets
            )

            # Exemplar mean: recompute the class mean over the selected
            # exemplars on normalized features, for nearest-mean classification.
            idx_dataset = data_manager.get_dataset(
                [],
                source="train",
                mode="test",
                appendent=(selected_exemplars, exemplar_targets),
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)
            self._class_means[class_idx, :] = mean
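
# Standalone sketch of the greedy herding loop in `_construct_exemplar`,
# kept self-contained for illustration (a hypothetical helper, not used by
# the class above):
def _herding_indices_sketch(feats, m):
    """Pick up to m row indices of `feats` (assumed L2-normalized) whose
    running mean stays closest to the overall class mean."""
    class_mean = feats.mean(axis=0)
    selected, acc = [], np.zeros_like(class_mean)
    remaining = list(range(len(feats)))
    for k in range(1, m + 1):
        if not remaining:
            break
        # Mean that would result from adding each remaining candidate.
        candidate_means = (feats[remaining] + acc) / k
        best = int(np.argmin(np.linalg.norm(class_mean - candidate_means, axis=1)))
        idx = remaining.pop(best)
        selected.append(idx)
        acc += feats[idx]
    return selected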