import logging

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import tensor2numpy

try:
    from quadprog import solve_qp
except ImportError:
    # quadprog is only needed for the gradient projection in _update_representation.
    solve_qp = None

EPSILON = 1e-8

# Hyperparameters for the first task.
init_epoch = 1
init_lr = 0.1
init_milestones = [40, 60, 80]
init_lr_decay = 0.1
init_weight_decay = 0.0005

# Hyperparameters for the incremental tasks.
epochs = 1
lrate = 0.1
milestones = [20, 40, 60]
lrate_decay = 0.1
batch_size = 16
weight_decay = 2e-4
num_workers = 4


class GEM(BaseLearner):
    """Gradient Episodic Memory (Lopez-Paz & Ranzato, 2017): rehearsal with per-task gradient constraints."""

    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)
        self.previous_data = None
        self.previous_label = None

    def after_task(self):
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if self._cur_task > 0:
            # Stack the exemplar memory of all previous tasks into two tensors.
            previous_dataset = data_manager.get_dataset(
                [], source="train", mode="train", appendent=self._get_memory()
            )
            self.previous_data = []
            self.previous_label = []
            for _, data_, label_ in previous_dataset:
                self.previous_data.append(data_)
                self.previous_label.append(label_)
            self.previous_data = torch.stack(self.previous_data)
            self.previous_label = torch.tensor(self.previous_label)

        # Train on the new task, then rebuild the rehearsal memory.
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
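    # Note on the episodic memory assembled above: `previous_data` / `previous_label`
    # hold the raw exemplar tensors of all past tasks. `_update_representation`
    # later slices this memory by label range, which assumes that every task
    # contributes the same number of classes.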
    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        # Plain cross-entropy training for the first task (no old tasks to protect).
        prog_bar = tqdm(range(init_epoch))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                )

            prog_bar.set_description(info)

        logging.info(info)

    def _store_grad(self, G, col, grad_numels):
        # Flatten the current parameter gradients into column `col` of G.
        j = 0
        for params in self._network.parameters():
            if params.grad is not None:
                stpt = sum(grad_numels[:j])
                endpt = sum(grad_numels[: j + 1])
                G[stpt:endpt, col].data.copy_(params.grad.data.view(-1))
            j += 1

    def _overwrite_grad(self, new_grad, grad_numels):
        # Write a flattened (projected) gradient back into the parameter .grad fields.
        j = 0
        for params in self._network.parameters():
            if params.grad is not None:
                stpt = sum(grad_numels[:j])
                endpt = sum(grad_numels[: j + 1])
                params.grad.data.copy_(
                    new_grad[stpt:endpt].contiguous().view(params.grad.data.size())
                )
            j += 1

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(epochs))

        # G holds one flattened gradient per column: columns 0..cur_task-1 for the
        # memory of the past tasks, column cur_task for the current batch.
        grad_numels = [params.data.numel() for params in self._network.parameters()]
        G = torch.zeros((sum(grad_numels), self._cur_task + 1)).to(self._device)

        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                # Number of classes per task (assumed equal across tasks).
                incremental_step = self._total_classes - self._known_classes

                # Gradient of each past task, computed on its slice of the memory.
                for k in range(0, self._cur_task):
                    optimizer.zero_grad()
                    mask = torch.where(
                        (self.previous_label >= k * incremental_step)
                        & (self.previous_label < (k + 1) * incremental_step)
                    )[0]
                    data_ = self.previous_data[mask].to(self._device)
                    label_ = self.previous_label[mask].to(self._device)
                    pred_ = self._network(data_)["logits"]
                    # Restrict the loss to task k's own classes.
                    pred_[:, : k * incremental_step].data.fill_(-10e10)
                    pred_[:, (k + 1) * incremental_step :].data.fill_(-10e10)

                    loss_ = F.cross_entropy(pred_, label_)
                    loss_.backward()
                    self._store_grad(G, k, grad_numels)

                # Gradient of the current batch, restricted to the new classes.
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                logits[:, : self._known_classes].data.fill_(-10e10)
                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                self._store_grad(G, self._cur_task, grad_numels)
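                # GEM update rule: the step must not increase the loss of any past
                # task, i.e. <g, g_k> >= 0 for every stored gradient g_k. When some
                # dot product is negative, g is replaced by its closest feasible
                # gradient (in squared L2 norm), obtained from the dual QP
                #     min_v 0.5 * v^T (G_old G_old^T) v + (G_old g)^T v,  v >= 0,
                # and recovered as g_tilde = G_old^T v* + g (solved with quadprog).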
                dotprod = torch.mm(
                    G[:, self._cur_task].unsqueeze(0), G[:, : self._cur_task]
                )
                if (dotprod < 0).sum() > 0:
                    old_grad = G[:, : self._cur_task].cpu().t().double().numpy()
                    cur_grad = G[:, self._cur_task].cpu().contiguous().double().numpy()

                    C = old_grad @ old_grad.T
                    p = old_grad @ cur_grad
                    A = np.eye(old_grad.shape[0])
                    b = np.zeros(old_grad.shape[0])
                    v = solve_qp(C, -p, A, b)[0]
                    new_grad = old_grad.T @ v + cur_grad
                    new_grad = torch.tensor(new_grad).float().to(self._device)

                    # Sanity check: the projected gradient must satisfy all constraints
                    # (up to a small numerical tolerance).
                    new_dotprod = torch.mm(
                        new_grad.unsqueeze(0), G[:, : self._cur_task]
                    )
                    assert (new_dotprod < -0.01).sum() == 0

                    # Overwrite the parameter gradients with the projected gradient.
                    self._overwrite_grad(new_grad, grad_numels)

                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)
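# The block below is not part of the learner. It is a minimal, self-contained
# sketch of the same quadprog-based projection that _update_representation
# performs, using synthetic, randomly generated "gradients" (the names
# `num_tasks`, `dim`, `old`, and `cur` are toy placeholders, not framework API).
# It only runs when this file is executed directly and quadprog is installed.
if __name__ == "__main__" and solve_qp is not None:
    rng = np.random.default_rng(0)
    num_tasks, dim = 3, 32
    old = rng.normal(size=(num_tasks, dim))  # rows: flattened past-task gradients
    cur = rng.normal(size=dim)               # flattened current-task gradient

    if (old @ cur < 0).any():  # projection is only needed when a constraint is violated
        C = old @ old.T
        p = old @ cur
        v = solve_qp(C, -p, np.eye(num_tasks), np.zeros(num_tasks))[0]
        projected = old.T @ v + cur
        print("violations before:", int((old @ cur < 0).sum()))
        print("violations after: ", int((old @ projected < -1e-6).sum()))
    else:
        print("random gradients already satisfy the GEM constraints")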