import logging
import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import count_parameters, tensor2numpy

EPSILON = 1e-8


class PASS(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = IncrementalNet(args, False)
        self._protos = []    # class-mean prototypes of past classes
        self._radius = 0     # global noise scale for prototype augmentation
        self._radiuses = []  # per-class average feature variances

    def after_task(self):
        self._known_classes = self._total_classes
        self._old_network = self._network.copy().freeze()
        if hasattr(self._old_network, "module"):
            self.old_network_module_ptr = self._old_network.module
        else:
            self.old_network_module_ptr = self._old_network
        # self.save_checkpoint("{}_{}_{}".format(self.args["model_name"], self.args["init_cls"], self.args["increment"]))

    def incremental_train(self, data_manager):
        self.data_manager = data_manager
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        # Four logits per class: one for each 0/90/180/270-degree rotation,
        # used as self-supervised label augmentation.
        self._network.update_fc(self._total_classes * 4)
        self._network_module_ptr = self._network
        logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))
        logging.info('All params: {}'.format(count_parameters(self._network)))
        logging.info('Trainable params: {}'.format(count_parameters(self._network, True)))

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source='train',
            mode='train',
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source='test', mode='test'
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=self.args["num_workers"],
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        resume = False
        # Add task ids to this list to resume them from saved checkpoints
        # instead of retraining.
        if self._cur_task in []:
            self._network.load_state_dict(
                torch.load(
                    "{}_{}_{}_{}.pkl".format(
                        self.args["model_name"],
                        self.args["init_cls"],
                        self.args["increment"],
                        self._cur_task,
                    )
                )["model_state_dict"]
            )
            resume = True
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module
        if not resume:
            self._epoch_num = self.args["epochs"]
            optimizer = torch.optim.Adam(
                self._network.parameters(),
                lr=self.args["lr"],
                weight_decay=self.args["weight_decay"],
            )
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"]
            )
            self._train_function(train_loader, test_loader, optimizer, scheduler)
        self._build_protos()

    def _build_protos(self):
        with torch.no_grad():
            for class_idx in range(self._known_classes, self._total_classes):
                data, targets, idx_dataset = self.data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source='train',
                    mode='test',
                    ret_data=True,
                )
                idx_loader = DataLoader(
                    idx_dataset,
                    batch_size=self.args["batch_size"],
                    shuffle=False,
                    num_workers=4,
                )
                vectors, _ = self._extract_vectors(idx_loader)
                class_mean = np.mean(vectors, axis=0)
                self._protos.append(class_mean)
                cov = np.cov(vectors.T)
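                # The "radius" follows PASS: average the per-dimension feature
                # variance tr(cov)/d over all new classes, then take the square
                # root to get a single global noise scale for prototype
                # augmentation. A minimal standalone sketch of the same
                # computation, assuming a hypothetical `vectors_per_class`
                # list of (n_c, d) feature arrays:
                #
                #   radiuses = [np.trace(np.cov(v.T)) / v.shape[1]
                #               for v in vectors_per_class]
                #   radius = np.sqrt(np.mean(radiuses))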
                # Average per-dimension variance of this class; pooled into a
                # single radius after the loop.
                self._radiuses.append(np.trace(cov) / vectors.shape[1])
            self._radius = np.sqrt(np.mean(self._radiuses))

    def _train_function(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self._epoch_num))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            losses_clf, losses_fkd, losses_proto = 0.0, 0.0, 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs = inputs.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)
                # Self-supervised label augmentation: rotate each image by
                # 0/90/180/270 degrees and give each rotation its own label,
                # so class c occupies logits 4c..4c+3.
                inputs = torch.stack(
                    [torch.rot90(inputs, k, (2, 3)) for k in range(4)], 1
                )
                inputs = inputs.view(-1, 3, 320, 320)
                targets = torch.stack([targets * 4 + k for k in range(4)], 1).view(-1)

                logits, loss_clf, loss_fkd, loss_proto = self._compute_pass_loss(inputs, targets)
                loss = loss_clf + loss_fkd + loss_proto
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses += loss.item()
                losses_clf += loss_clf.item()
                losses_fkd += loss_fkd.item()
                losses_proto += loss_proto.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            # Run the test set only every fifth epoch.
            if epoch % 5 != 0:
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num,
                    losses / len(train_loader), losses_clf / len(train_loader),
                    losses_fkd / len(train_loader), losses_proto / len(train_loader),
                    train_acc)
            else:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num,
                    losses / len(train_loader), losses_clf / len(train_loader),
                    losses_fkd / len(train_loader), losses_proto / len(train_loader),
                    train_acc, test_acc)
            prog_bar.set_description(info)
        logging.info(info)

    def _compute_pass_loss(self, inputs, targets):
        logits = self._network(inputs)["logits"]
        loss_clf = F.cross_entropy(logits / self.args["temp"], targets)
        if self._cur_task == 0:
            # First task: no old network or prototypes yet, so the
            # distillation and prototype losses are zero.
            return logits, loss_clf, torch.tensor(0.0), torch.tensor(0.0)
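        # Later tasks add two regularizers on top of the rotation-augmented
        # classification loss (computed below): loss_fkd keeps the new feature
        # extractor close to the frozen old one, and loss_proto replays
        # Gaussian-perturbed old-class prototypes through the current
        # classifier. A hypothetical worked example of the sampling sizes:
        # with 10 known classes, 5 new classes, and batch_size 64, the code
        # draws 64 * int(10 / 5) = 128 old-class prototypes per step, and a
        # sampled class c is supervised at logit 4c, its 0-degree rotation
        # slot.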
        # Feature-level knowledge distillation against the frozen old network.
        features = self._network_module_ptr.extract_vector(inputs)
        features_old = self.old_network_module_ptr.extract_vector(inputs)
        loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2)

        # Prototype augmentation: sample old-class means (scaled by the
        # old/new class ratio) and perturb them with Gaussian noise of
        # standard deviation self._radius.
        index = np.random.choice(
            range(self._known_classes),
            size=self.args["batch_size"] * int(self._known_classes / (self._total_classes - self._known_classes)),
            replace=True,
        )
        proto_features = np.array(self._protos)[index]
        proto_targets = 4 * index  # 0-degree-rotation logit of each sampled class
        proto_features = proto_features + np.random.normal(0, 1, proto_features.shape) * self._radius
        proto_features = torch.from_numpy(proto_features).float().to(self._device, non_blocking=True)
        proto_targets = torch.from_numpy(proto_targets).to(self._device, non_blocking=True)

        proto_logits = self._network_module_ptr.fc(proto_features)["logits"]
        loss_proto = self.args["lambda_proto"] * F.cross_entropy(
            proto_logits / self.args["temp"], proto_targets
        )
        return logits, loss_clf, loss_fkd, loss_proto

    def _compute_accuracy(self, model, loader):
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                # Keep only the 0-degree-rotation logit of each class.
                outputs = model(inputs)["logits"][:, ::4]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)
        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs)["logits"][:, ::4]
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())
        return np.concatenate(y_pred), np.concatenate(y_true)

    def eval_task(self, save_conf=True):
        y_pred, y_true = self._eval_cnn(self.test_loader)
        cnn_accy = self._evaluate(y_pred, y_true)

        if hasattr(self, '_class_means'):
            y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
            nme_accy = self._evaluate(y_pred, y_true)
        elif hasattr(self, '_protos'):
            # NME evaluation with L2-normalized prototypes.
            y_pred, y_true = self._eval_nme(
                self.test_loader,
                self._protos / np.linalg.norm(self._protos, axis=1)[:, None],
            )
            nme_accy = self._evaluate(y_pred, y_true)
        else:
            nme_accy = None

        if save_conf:
            _pred = y_pred.T[0]
            _pred_path = os.path.join(self.args['logfilename'], "pred.npy")
            _target_path = os.path.join(self.args['logfilename'], "target.npy")
            np.save(_pred_path, _pred)
            np.save(_target_path, y_true)
            _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}")
            os.makedirs(_save_dir, exist_ok=True)
            _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv")
            with open(_save_path, "a+") as f:
                f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n")

        return cnn_accy, nme_accy
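
# A minimal usage sketch (hypothetical config values; the DataManager and
# trainer wiring follow the surrounding codebase and are not shown here):
#
#   args = {"model_name": "pass", "init_cls": 10, "increment": 10,
#           "batch_size": 64, "num_workers": 4, "epochs": 100, "lr": 1e-3,
#           "weight_decay": 2e-4, "step_size": 45, "gamma": 0.1, "temp": 0.1,
#           "lambda_fkd": 10, "lambda_proto": 10, ...}
#   learner = PASS(args)
#   for _ in range(num_tasks):
#       learner.incremental_train(data_manager)
#       cnn_accy, nme_accy = learner.eval_task()
#       learner.after_task()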