import logging

import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from models.base import BaseLearner
from utils.inc_net import FOSTERNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

# Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce foster.

EPSILON = 1e-8


class FOSTER(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = FOSTERNet(args, False)
        self._snet = None
        self.beta1 = args["beta1"]
        self.beta2 = args["beta2"]
        self.per_cls_weights = None
        self.is_teacher_wa = args["is_teacher_wa"]
        self.is_student_wa = args["is_student_wa"]
        self.lambda_okd = args["lambda_okd"]
        self.wa_value = args["wa_value"]
        self.oofc = args["oofc"].lower()

    def after_task(self):
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)
        self._known_classes = len(checkpoint["classes"])
        self.class_list = np.array(checkpoint["classes"])
        self.label_list = checkpoint["label_list"]
        self._network.update_fc(self._known_classes)
        self._network.load_checkpoint(checkpoint["network"])
        self._network.to(self._device)
        self._cur_task = 0

    def save_checkpoint(self, filename):
        self._network.cpu()
        save_dict = {
            "classes": self.data_manager.get_class_list(self._cur_task),
            "network": {
                "convnet": self._network.convnets[0].state_dict(),
                "fc": self._network.fc.state_dict(),
            },
            "label_list": self.data_manager.get_label_list(self._cur_task),
        }
        torch.save(
            save_dict,
            "./{}/{}_{}.pkl".format(filename, self.args["model_name"], self._cur_task),
        )

    def incremental_train(self, data_manager):
        self.data_manager = data_manager
        if hasattr(self.data_manager, "label_list") and hasattr(self, "label_list"):
            self.data_manager.label_list = (
                list(self.label_list.values()) + self.data_manager.label_list
            )
        self._cur_task += 1
        if self._cur_task > 1:
            self._network = self._snet
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        self._network_module_ptr = self._network
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task > 0:
            for p in self._network.convnets[0].parameters():
                p.requires_grad = False
            for p in self._network.oldfc.parameters():
                p.requires_grad = False

        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=self.args["num_workers"],
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        # self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def train(self):
        self._network_module_ptr.train()
        self._network_module_ptr.convnets[-1].train()
        if self._cur_task >= 1:
            self._network_module_ptr.convnets[0].eval()

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module
        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["init_epochs"]
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            # Class-balanced (effective-number) per-class weights for the boosting phase.
            cls_num_list = [self.samples_old_class] * self._known_classes + [
                self.samples_new_class(i)
                for i in range(self._known_classes, self._total_classes)
            ]
            effective_num = 1.0 - np.power(self.beta1, cls_num_list)
            per_cls_weights = (1.0 - self.beta1) / np.array(effective_num)
            per_cls_weights = (
                per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            )
            logging.info("per cls weights : {}".format(per_cls_weights))
            self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)

            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.args["lr"],
                momentum=0.9,
                weight_decay=self.args["weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["boosting_epochs"]
            )
            if self.oofc == "az":
                # Zero the weights connecting old-backbone features to new-class outputs.
                for i, p in enumerate(self._network_module_ptr.fc.parameters()):
                    if i == 0:
                        p.data[
                            self._known_classes :, : self._network_module_ptr.out_dim
                        ] = torch.tensor(0.0)
            elif self.oofc != "ft":
                assert 0, "not implemented"
            self._feature_boosting(train_loader, test_loader, optimizer, scheduler)
            if self.is_teacher_wa:
                self._network_module_ptr.weight_align(
                    self._known_classes,
                    self._total_classes - self._known_classes,
                    self.wa_value,
                )
            else:
                logging.info("do not weight align teacher!")

            # Recompute per-class weights with beta2 for the compression (distillation) phase.
            cls_num_list = [self.samples_old_class] * self._known_classes + [
                self.samples_new_class(i)
                for i in range(self._known_classes, self._total_classes)
            ]
            effective_num = 1.0 - np.power(self.beta2, cls_num_list)
            per_cls_weights = (1.0 - self.beta2) / np.array(effective_num)
            per_cls_weights = (
                per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            )
            logging.info("per cls weights : {}".format(per_cls_weights))
            self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
            self._feature_compression(train_loader, test_loader)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.args["init_epochs"]))
        for _, epoch in enumerate(prog_bar):
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                logits = self._network(inputs)["logits"]
                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["init_epochs"],
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["init_epochs"],
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
        logging.info(info)

    def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.args["boosting_epochs"]))
        for _, epoch in enumerate(prog_bar):
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_fe = 0.0
            losses_kd = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                outputs = self._network(inputs)
                logits, fe_logits, old_logits = (
                    outputs["logits"],
                    outputs["fe_logits"],
                    outputs["old_logits"].detach(),
                )
                # Classification on the full head, on the feature-expansion head,
                # and knowledge distillation from the frozen old head.
                loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets)
                loss_fe = F.cross_entropy(fe_logits, targets)
                loss_kd = self.lambda_okd * _KD_loss(
                    logits[:, : self._known_classes], old_logits, self.args["T"]
                )
                loss = loss_clf + loss_fe + loss_kd
                optimizer.zero_grad()
                loss.backward()
                if self.oofc == "az":
                    # Keep the old-feature part of the new-class weights frozen at zero.
                    for i, p in enumerate(self._network_module_ptr.fc.parameters()):
                        if i == 0:
                            p.grad.data[
                                self._known_classes :,
                                : self._network_module_ptr.out_dim,
                            ] = torch.tensor(0.0)
                elif self.oofc != "ft":
                    assert 0, "not implemented"
                optimizer.step()
                losses += loss.item()
                losses_fe += loss_fe.item()
                losses_clf += loss_clf.item()
                losses_kd += (
                    self._known_classes / self._total_classes
                ) * loss_kd.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["boosting_epochs"],
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_kd / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["boosting_epochs"],
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_fe / len(train_loader),
                    losses_kd / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
        logging.info(info)

    def _feature_compression(self, train_loader, test_loader):
        # Distill the boosted two-branch teacher into a single-backbone student (SNet).
        self._snet = FOSTERNet(self.args, False)
        self._snet.update_fc(self._total_classes)
        if len(self._multiple_gpus) > 1:
            self._snet = nn.DataParallel(self._snet, self._multiple_gpus)
        if hasattr(self._snet, "module"):
            self._snet_module_ptr = self._snet.module
        else:
            self._snet_module_ptr = self._snet
        self._snet.to(self._device)
        self._snet_module_ptr.convnets[0].load_state_dict(
            self._network_module_ptr.convnets[0].state_dict()
        )
        self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc)

        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self._snet.parameters()),
            lr=self.args["lr"],
            momentum=0.9,
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.args["compression_epochs"]
        )
        self._network.eval()
        prog_bar = tqdm(range(self.args["compression_epochs"]))
        for _, epoch in enumerate(prog_bar):
            self._snet.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(
                    self._device, non_blocking=True
                ), targets.to(self._device, non_blocking=True)
                dark_logits = self._snet(inputs)["logits"]
                with torch.no_grad():
                    outputs = self._network(inputs)
                    logits, old_logits, fe_logits = (
                        outputs["logits"],
                        outputs["old_logits"],
                        outputs["fe_logits"],
                    )
                loss_dark = self.BKD(dark_logits, logits, self.args["T"])
                loss = loss_dark
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                _, preds = torch.max(dark_logits[: targets.shape[0]], dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._snet, test_loader)
                info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["compression_epochs"],
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args["compression_epochs"],
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
        logging.info(info)
        if len(self._multiple_gpus) > 1:
            self._snet = self._snet.module
        if self.is_student_wa:
            self._snet.weight_align(
                self._known_classes,
                self._total_classes - self._known_classes,
                self.wa_value,
            )
        else:
            logging.info("do not weight align student!")
        if self._cur_task > 1:
            self._network = self._snet

        self._snet.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(self._device, non_blocking=True)
            with torch.no_grad():
                outputs = self._snet(inputs)["logits"]
            predicts = torch.topk(
                outputs, k=self.topk, dim=1, largest=True, sorted=True
            )[1]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())
        y_pred = np.concatenate(y_pred)
        y_true = np.concatenate(y_true)
        cnn_accy = self._evaluate(y_pred, y_true)
        logging.info("darknet eval: ")
        logging.info("CNN top1 curve: {}".format(cnn_accy["top1"]))
        logging.info("CNN top5 curve: {}".format(cnn_accy["top5"]))

    @property
    def samples_old_class(self):
        if self._fixed_memory:
            return self._memory_per_class
        else:
            assert self._total_classes != 0, "Total classes is 0"
            return self._memory_size // self._known_classes

    def samples_new_class(self, index):
        if self.args["dataset"] == "cifar100":
            return 500
        else:
            return self.data_manager.getlen(index)

    def BKD(self, pred, soft, T):
        # Balanced knowledge distillation: reweight the teacher's softened
        # probabilities by the per-class weights before the soft cross-entropy.
        pred = torch.log_softmax(pred / T, dim=1)
        soft = torch.softmax(soft / T, dim=1)
        soft = soft * self.per_cls_weights
        soft = soft / soft.sum(1)[:, None]
        return -1 * torch.mul(soft, pred).sum() / pred.shape[0]


def _KD_loss(pred, soft, T):
    pred = torch.log_softmax(pred / T, dim=1)
    soft = torch.softmax(soft / T, dim=1)
    return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
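

# A minimal, self-contained sanity check for the distillation losses above.
# It is illustrative only (not part of the FOSTER training pipeline) and uses
# randomly generated logits and uniform per-class weights as assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes, batch_size, T = 10, 4, 2.0
    student_logits = torch.randn(batch_size, num_classes)
    teacher_logits = torch.randn(batch_size, num_classes)

    # _KD_loss is the soft cross-entropy between temperature-scaled distributions,
    # averaged over the batch; it should be finite and non-negative.
    kd = _KD_loss(student_logits, teacher_logits, T)
    assert torch.isfinite(kd) and kd.item() >= 0.0

    # With uniform per-class weights, BKD leaves the teacher distribution
    # unchanged and therefore reduces to the same value as _KD_loss.
    class _DummyLearner:
        per_cls_weights = torch.ones(num_classes)

    bkd = FOSTER.BKD(_DummyLearner(), student_logits, teacher_logits, T)
    assert torch.allclose(bkd, kd, atol=1e-6)
    print("KD loss: {:.4f}, BKD loss (uniform weights): {:.4f}".format(kd.item(), bkd.item()))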