"""EWC learner: Elastic Weight Consolidation for class-incremental learning.

Reference: Kirkpatrick et al., "Overcoming catastrophic forgetting in neural
networks," PNAS 2017.
"""
import logging

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import tensor2numpy

EPSILON = 1e-8

# Hyperparameters for the first (initial) task.
init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005

# Hyperparameters for the subsequent incremental tasks.
epochs = 180
lrate = 0.1
milestones = [70, 120, 150]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 4
T = 2
lamda = 1000  # weight of the EWC penalty in the total loss
fishermax = 0.0001  # cap on per-parameter Fisher values


class EWC(BaseLearner):
    """Elastic Weight Consolidation: a quadratic penalty anchors parameters
    that were important for previous tasks, with per-parameter importance
    estimated by the diagonal of the Fisher information matrix."""

    def __init__(self, args):
        super().__init__(args)
        self.fisher = None
        self._network = IncrementalNet(args, False)

    def after_task(self):
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # Train only on the new classes; evaluate on all classes seen so far.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

        # Consolidate: estimate the Fisher diagonal on the current task and
        # blend it with the previous estimate, weighting old and new tasks by
        # the fraction of classes they cover.
        if self.fisher is None:
            self.fisher = self.getFisherDiagonal(self.train_loader)
        else:
            alpha = self._known_classes / self._total_classes
            new_fisher = self.getFisherDiagonal(self.train_loader)
            for n, p in new_fisher.items():
                # Only the leading entries exist in the old Fisher, because
                # the classifier head grows with each task.
                new_fisher[n][: len(self.fisher[n])] = (
                    alpha * self.fisher[n]
                    + (1 - alpha) * new_fisher[n][: len(self.fisher[n])]
                )
            self.fisher = new_fisher
        # Snapshot the current parameters as the anchor for the EWC penalty.
        self.mean = {
            n: p.clone().detach()
            for n, p in self._network.named_parameters()
            if p.requires_grad
        }

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Plain cross-entropy training for the first task (no EWC penalty)."""
        prog_bar = tqdm(range(init_epoch))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            # Evaluate on the test set every fifth epoch only.
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Training for incremental tasks: cross-entropy on the new classes
        plus the Fisher-weighted EWC penalty on the anchored parameters."""
        prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                # Classification loss over the new logits only; targets are
                # shifted so that the new classes start at index 0.
                loss_clf = F.cross_entropy(
                    logits[:, self._known_classes :], targets - self._known_classes
                )
                loss_ewc = self.compute_ewc()
                loss = loss_clf + lamda * loss_ewc

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)

    def compute_ewc(self):
        """EWC penalty: sum_i fisher_i * (theta_i - mean_i)^2 / 2, restricted
        to the leading entries of each parameter that existed when the anchor
        was recorded (the classifier head has grown since then)."""
        loss = 0
        if len(self._multiple_gpus) > 1:
            # Unwrap DataParallel so parameter names match the Fisher keys.
            named_parameters = self._network.module.named_parameters()
        else:
            named_parameters = self._network.named_parameters()
        for n, p in named_parameters:
            if n in self.fisher.keys():
                loss += (
                    torch.sum(
                        self.fisher[n]
                        * (p[: len(self.mean[n])] - self.mean[n]).pow(2)
                    )
                    / 2
                )
        return loss

    def getFisherDiagonal(self, train_loader):
        """Estimate the diagonal of the Fisher information as the squared
        gradients of the loss, averaged over batches and capped at fishermax."""
        fisher = {
            n: torch.zeros(p.shape).to(self._device)
            for n, p in self._network.named_parameters()
            if p.requires_grad
        }
        self._network.train()
        # The optimizer is used only to zero gradients; no step is taken.
        optimizer = optim.SGD(self._network.parameters(), lr=lrate)
        for i, (_, inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(self._device), targets.to(self._device)
            logits = self._network(inputs)["logits"]
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            for n, p in self._network.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2).clone()
        for n, p in fisher.items():
            fisher[n] = p / len(train_loader)
            fisher[n] = torch.min(fisher[n], torch.tensor(fishermax))
        return fisher
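

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the learner above): a minimal, self-
# contained demonstration of the two EWC ingredients implemented in this
# file -- the diagonal Fisher estimate from squared gradients (cf.
# getFisherDiagonal) and the quadratic penalty
# sum_i fisher_i * (theta_i - mean_i)^2 / 2 (cf. compute_ewc). The toy
# linear model and random data below are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    toy = nn.Linear(4, 3)  # stand-in for the incremental network
    xs = torch.randn(32, 4)
    ys = torch.randint(0, 3, (32,))

    # Diagonal Fisher estimate: squared gradient of the loss per parameter
    # (a single batch here; the learner averages over the whole loader),
    # capped at fishermax as in getFisherDiagonal.
    toy.zero_grad()
    F.cross_entropy(toy(xs), ys).backward()
    toy_fisher = {
        n: torch.min(p.grad.pow(2), torch.tensor(fishermax))
        for n, p in toy.named_parameters()
    }
    # Snapshot the anchor parameters at the end of the "old task".
    toy_mean = {n: p.clone().detach() for n, p in toy.named_parameters()}

    # Drift the weights, as further training would; the EWC penalty grows in
    # proportion to the Fisher-weighted squared drift from the anchor.
    with torch.no_grad():
        for p in toy.parameters():
            p.add_(0.1)
    penalty = sum(
        torch.sum(toy_fisher[n] * (p - toy_mean[n]).pow(2)) / 2
        for n, p in toy.named_parameters()
    )
    print("EWC penalty after drift: {:.6f}".format(penalty.item()))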