| """ |
| @inproceedings{rebuffi2017icarl, |
title={iCaRL: Incremental Classifier and Representation Learning},
| author={Rebuffi, Sylvestre-Alvise and Kolesnikov, Alexander and Sperl, Georg and Lampert, Christoph H}, |
| booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition}, |
| pages={2001--2010}, |
| year={2017} |
| } |
| https://arxiv.org/abs/1611.07725 |
| """ |
|
|
import copy
import os
from copy import deepcopy
from typing import Iterator

import numpy as np
import PIL
# Explicitly import the submodule: `import PIL` alone does not guarantee
# that `PIL.Image` (used in calc_class_mean) is loaded.
import PIL.Image
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
|
|
class Model(nn.Module):
    """Classification network: a feature-extracting backbone followed by a
    single linear classifier head.

    The backbone is expected to return a dict containing a 'features' tensor
    of shape (batch, feat_dim) when called.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)

    def get_logits(self, x):
        """Extract backbone features from `x` and map them to class logits."""
        features = self.backbone(x)['features']
        return self.classifier(features)

    def forward(self, x):
        # Forward is a pure delegation to get_logits.
        return self.get_logits(x)
| |
| |
|
|
class ICarl(nn.Module):
    """iCaRL incremental learner (Rebuffi et al., CVPR 2017).

    Learns tasks sequentially, combining a cross-entropy loss over all
    classes seen so far with a knowledge-distillation loss against a frozen
    copy of the previous-task network.  At inference time, once per-class
    exemplar feature means are available, samples are classified with a
    nearest-class-mean (NCM) rule in feature space.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """
        Args:
            backbone: feature extractor whose call returns a dict containing
                a 'features' tensor (consumed by ``Model``).
            feat_dim (int): dimensionality of the backbone features.
            num_class (int): total number of classes across all tasks.
            **kwargs: must provide 'device', 'init_cls_num' (classes in the
                first task), 'inc_cls_num' (classes added per subsequent
                task) and 'task_num' (total number of tasks).
        """
        super().__init__()

        self.device = kwargs['device']

        # 0-based index of the task currently being learned.
        self.cur_task_id = 0

        # Class indexes introduced by the current task (set in before_task).
        self.cur_cls_indexes = None

        self.network = Model(backbone, feat_dim, num_class)

        # Frozen snapshot of the network after the previous task; serves as
        # the distillation teacher from the second task onward.
        self.old_network = None

        # Number of classes seen in previous tasks only.
        self.prev_cls_num = 0

        # Number of classes seen up to and including the current task.
        self.accu_cls_num = 0

        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.task_num = kwargs['task_num']

        # (num_seen_classes, feat_dim) tensor of L2-normalized exemplar
        # feature means, refreshed in after_task.
        self.class_means = None

    def get_parameters(self, config):
        """Return the trainable parameters of the working network."""
        return self.network.parameters()

    def observe(self, data):
        """Run one training step on a batch.

        Args:
            data (dict): batch with 'image' and 'label' tensors.

        Returns:
            tuple: (predictions, batch accuracy, loss tensor).
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        logits, loss = self.criterion(x, y)

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()

        return pred, acc / x.size(0), loss

    def inference(self, data):
        """Classify a batch, preferring the NCM rule when class means exist.

        Falls back to the linear classifier head (restricted to the classes
        seen so far) while the class means are unavailable or stale.

        Args:
            data (dict): batch with 'image' and 'label' tensors.

        Returns:
            tuple: (predictions, batch accuracy).
        """
        # Means are "fresh" only when one exists per seen class.
        if self.class_means is not None and len(self.class_means) == self.accu_cls_num:
            return self.NCM_classify(data)

        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        logits = self.network(x)[:, :self.accu_cls_num]
        pred = torch.argmax(logits, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def NCM_classify(self, data):
        """Classify a batch by its nearest exemplar class mean in feature space."""

        def metric(x, y):
            """Calculate the pair-wise squared Euclidean distance between `x` and `y`.

            Args:
                x (Tensor): to be calculated for distance, with shape (N, D)
                y (Tensor): to be calculated for distance, with shape (M, D),
                    where D is embedding size.

            Returns:
                pair squared-Euclidean distance tensor with shape (N, M),
                where dist[i][j] is the squared distance between x[i] and y[j].
                (No sqrt is taken; argmin over it is unaffected.)
            """
            n = x.size(0)
            m = y.size(0)
            x = x.unsqueeze(1).expand(n, m, -1)
            y = y.unsqueeze(0).expand(n, m, -1)
            return torch.pow(x - y, 2).sum(2)

        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        # Fix: drop the accidental duplicated assignment ('feats = feats = ...').
        feats = self.network.backbone(x)['features']
        feats = feats.view(feats.size(0), -1)
        distance = metric(feats, self.class_means)

        # Smallest distance to a class mean wins.
        pred = torch.argmin(distance, dim=1)
        acc = torch.sum(pred == y).item()

        return pred, acc / x.size(0)

    def forward(self, x):
        """Return logits restricted to the classes seen so far.

        Bug fix: the original indexed a single column
        (``[:, self.accu_cls_num]``); slice with ``:self.accu_cls_num`` so the
        result has shape (batch, accu_cls_num), consistent with inference()
        and criterion().
        """
        return self.network(x)[:, :self.accu_cls_num]

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Advance the class bookkeeping before training a new task."""
        if self.cur_task_id == 0:
            self.accu_cls_num = self.init_cls_num
        else:
            self.accu_cls_num += self.inc_cls_num

        # Indexes of the classes introduced by this task.
        self.cur_cls_indexes = np.arange(self.prev_cls_num, self.accu_cls_num)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Snapshot the network, refresh the exemplar buffer and class means."""
        # Freeze a copy of the trained network as the next distillation teacher.
        self.old_network = copy.deepcopy(self.network)
        self.old_network.eval()

        self.prev_cls_num = self.accu_cls_num

        # Shrink the per-class exemplar sets so the buffer stays within budget.
        buffer.reduce_old_data(self.cur_task_id, self.accu_cls_num)

        # Use the evaluation-time transforms when extracting exemplar features.
        val_transform = test_loaders[0].dataset.trfms
        buffer.update(self.network, train_loader, val_transform,
                      self.cur_task_id, self.accu_cls_num, self.cur_cls_indexes,
                      self.device)

        self.class_means = self.calc_class_mean(buffer,
                                                train_loader,
                                                val_transform,
                                                self.device).to(self.device)
        self.cur_task_id += 1

    def criterion(self, x, y):
        """Compute the classification (+ distillation) loss for a batch.

        Args:
            x (Tensor): input images, already on ``self.device``.
            y (Tensor): integer class labels.

        Returns:
            tuple: (logits over the classes seen so far, total loss tensor).
        """

        def _KD_loss(pred, soft, T=2):
            """
            Compute the knowledge distillation (KD) loss between the predicted logits and the soft target.
            Code Reference:
            KD loss function is borrowed from: https://github.com/G-U-N/PyCIL/blob/master/models/icarl.py
            """
            pred = torch.log_softmax(pred / T, dim=1)
            soft = torch.softmax(soft / T, dim=1)
            return -1 * torch.mul(soft, pred).sum() / pred.shape[0]

        cur_logits = self.network(x)[:, :self.accu_cls_num]
        loss_clf = F.cross_entropy(cur_logits, y)

        # From the second task on, distill the old network's responses on the
        # previously-seen classes to mitigate catastrophic forgetting.
        if self.cur_task_id > 0:
            old_logits = self.old_network(x)
            loss_kd = _KD_loss(
                cur_logits[:, : self.prev_cls_num],
                old_logits[:, : self.prev_cls_num],
            )
            loss = loss_clf + loss_kd
        else:
            loss = loss_clf

        return cur_logits, loss

    def calc_class_mean(self, buffer, train_loader, val_transform, device):
        """Compute an L2-normalized feature mean for every class in the buffer.

        Args:
            buffer: exemplar buffer exposing `images` (relative paths) and
                `labels` lists.
            train_loader (DataLoader): used only to recover dataset root/mode
                and loader settings.
            val_transform: evaluation-time image transforms.
            device: device to run feature extraction on.

        Returns:
            Tensor of shape (num_buffered_classes, feat_dim) on CPU, ordered
            by ascending class id.
        """

        class miniBufferDataset(Dataset):
            """Minimal dataset over the exemplar images stored in the buffer."""

            def __init__(self, root, mode, image_list, label_list, transforms):
                self.data_root = root
                self.mode = mode
                self.images = image_list
                self.labels = label_list
                self.transforms = transforms

            def __getitem__(self, idx):
                img_path = self.images[idx]
                label = self.labels[idx]
                image = PIL.Image.open(os.path.join(self.data_root, self.mode, img_path)).convert("RGB")
                image = self.transforms(image)
                return {"image": image, "label": label}

            def __len__(self):
                return len(self.labels)

        root_path = train_loader.dataset.data_root
        mode = train_loader.dataset.mode
        image_list = buffer.images
        label_list = buffer.labels
        ds = miniBufferDataset(root_path, mode, image_list, label_list, val_transform)

        icarl_loader = DataLoader(ds,
                                  batch_size=train_loader.batch_size,
                                  shuffle=False,
                                  num_workers=train_loader.num_workers,
                                  pin_memory=train_loader.pin_memory)

        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            self.network.eval()
            for data in icarl_loader:
                images = data['image'].to(device)
                labels = data['label'].to(device)
                feats = self.network.backbone(images)['features']
                # L2-normalize each feature vector before averaging.
                extracted_features.append(feats / feats.norm(dim=1).view(-1, 1))
                extracted_targets.extend(labels)

        extracted_features = torch.cat(extracted_features).cpu()
        extracted_targets = torch.stack(extracted_targets).cpu()

        all_class_means = []
        for curr_cls in np.unique(extracted_targets):
            # Rows belonging to the current class.
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            cls_feats = extracted_features[cls_ind]
            # Compute the mean once (was computed twice), then renormalize it.
            cls_mean = cls_feats.mean(0)
            all_class_means.append(cls_mean / cls_mean.norm())

        return torch.stack(all_class_means)
|
|