'''
@misc{mcdonnell2024ranpacrandomprojectionspretrained,
      title={RanPAC: Random Projections and Pre-trained Models for Continual Learning},
      author={Mark D. McDonnell and Dong Gong and Amin Parvaneh and Ehsan Abbasnejad and Anton van den Hengel},
      year={2024},
      eprint={2307.02251},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2307.02251},
}

Code Reference:
https://github.com/RanPAC/RanPAC
'''
| |
|
| | import copy |
| | import math |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer |
| | from .backbone.clip import CLIP, tokenize |
| | from .backbone.vit import ViTZoo, ViT_in21k_adapter |
| |
|
| | VIT = ViT_in21k_adapter |
| | CLIP = CLIP |
| |
|
class CosineLinear(nn.Module):
    """Classification head used by RanPAC, with two operating modes.

    Default mode: cosine similarity between L2-normalized inputs and
    L2-normalized weight rows (class prototypes).

    Random-projection (RP) mode (``use_RP`` True): inputs are projected with
    the fixed random matrix ``W_rand``, passed through ReLU, then through a
    plain linear map. In this mode ``weight`` is expected to be re-assigned
    externally to shape (num_classes, M) — see RanPAC.update_rp_classifier.

    The output is scaled by the learnable scalar ``sigma`` in both modes.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

        # RP state; toggled/assigned externally (RanPAC.update_rp_classifier).
        self.use_RP = False
        self.W_rand = None

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]; sigma starts at 1."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.sigma.data.fill_(1)

    def forward(self, input):
        """Return class scores of shape (batch, out_features), scaled by sigma."""
        if not self.use_RP:
            # Cosine similarity: normalize both features and prototypes.
            out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        else:
            # The original `assert 0` guard is stripped under `python -O`,
            # which would silently fall through to wrong logits — raise instead.
            if self.W_rand is None:
                raise RuntimeError('use_RP is enabled but W_rand has not been assigned')
            inn = F.relu(input @ self.W_rand)
            out = F.linear(inn, self.weight)

        out = self.sigma * out

        return out
| |
|
class Network(nn.Module):
    """Backbone plus CosineLinear head used by RanPAC.

    Feature extraction depends on the backbone family:
      * VIT: the backbone output is used directly.
      * CLIP: image features are concatenated with the text features of the
        class ranked highest by the CLIP logits for each image.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self._cur_task_id = -1
        self.backbone = backbone
        self.device = device
        self.classifier = None  # rebuilt for every task in update_classifer

        if isinstance(self.backbone, VIT):
            self.feature_dim = self.backbone.feat_dim
        elif isinstance(self.backbone, CLIP):
            # image embedding dim + text embedding dim (features are concatenated)
            self.feature_dim = self.backbone.visual.output_dim + self.backbone.transformer.width
            self.accm_class_names = []    # class names accumulated over all tasks
            self.curr_class_names = []    # class names of the current task only
            self.accm_text_tokens = None
            self.curr_text_tokens = None
        else:
            # Fail fast instead of leaving feature_dim undefined.
            raise TypeError(f'unsupported backbone: {type(backbone).__name__}')

        self.prompt_template = kwargs['prompt_template']

    def update_classifer(self, num_classes, train_loader):
        """Advance the task counter and rebuild the head for `num_classes`.

        (Method name keeps its original spelling for caller compatibility.)
        For CLIP backbones, also refresh the tokenized text prompts of the
        current and accumulated class names.
        """
        if isinstance(self.backbone, VIT):
            pass
        elif isinstance(self.backbone, CLIP):
            self.curr_class_names = train_loader.dataset.get_class_names()
            self.accm_class_names += self.curr_class_names

            self.curr_text_tokens = tokenize(
                [self.prompt_template.format(c) for c in self.curr_class_names]
            ).to(self.device)

            self.accm_text_tokens = tokenize(
                [self.prompt_template.format(c) for c in self.accm_class_names]
            ).to(self.device)
        else:
            # `assert 0` would be stripped under `python -O`; raise explicitly.
            raise TypeError(f'unsupported backbone: {type(self.backbone).__name__}')

        self._cur_task_id += 1
        del self.classifier
        self.classifier = CosineLinear(self.feature_dim, num_classes).to(self.device)

    def _clip_features(self, x, text_tokens):
        """Concatenate CLIP image features with the text features of the
        highest-scoring class for each image (per logits_per_image)."""
        features_image, features_text, logits_per_image, logits_per_text = self.backbone(x, text_tokens)
        max_indices = logits_per_image.softmax(dim=-1).argmax(dim=1)
        max_features = features_text[max_indices]
        return torch.cat([features_image, max_features], dim=1)

    def get_feature(self, x):
        """Return backbone features for `x` (current-task prompts for CLIP)."""
        if isinstance(self.backbone, VIT):
            return self.backbone(x)
        elif isinstance(self.backbone, CLIP):
            return self._clip_features(x, self.curr_text_tokens)
        else:
            raise TypeError(f'unsupported backbone: {type(self.backbone).__name__}')

    def forward(self, x, inference=False):
        """Features -> classifier logits.

        With a CLIP backbone, `inference=True` uses the accumulated
        (all-classes-seen) prompts instead of the current-task prompts.
        """
        if isinstance(self.backbone, VIT):
            features = self.backbone(x)
        elif isinstance(self.backbone, CLIP):
            tokens = self.accm_text_tokens if inference else self.curr_text_tokens
            features = self._clip_features(x, tokens)
        else:
            raise TypeError(f'unsupported backbone: {type(self.backbone).__name__}')

        return self.classifier(features)
| |
|
class RanPAC(nn.Module):
    """RanPAC continual learner (McDonnell et al., arXiv:2307.02251).

    Optionally gradient-trains only on the first task; every task thereafter
    is absorbed by a closed-form ridge-regression classifier computed on
    randomly projected features, whose sufficient statistics are accumulated
    across tasks in ``self.Q`` and ``self.G``.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self._network = Network(backbone, device, **kwargs)

        self.device = device
        self.first_session_training = kwargs["first_session_training"]
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.total_cls_num = kwargs['total_cls_num']
        self.task_num = kwargs["task_num"]

        # M: width of the random projection (number of random features).
        self.M = kwargs['M']

        self._known_classes = 0
        self._classes_seen_so_far = 0
        self._skip_train = False

        self._network.to(self.device)

        # For CLIP backbones, only parameters with 'adapt' in their name
        # (adapter weights) remain trainable.
        if isinstance(backbone, CLIP):
            for name, param in self._network.named_parameters():
                if 'adapt' not in name:
                    param.requires_grad = False

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the seen-class count, rebuild the head, decide whether to train."""
        if task_idx == 0:
            self._classes_seen_so_far = self.init_cls_num
        elif task_idx > 0:
            self._classes_seen_so_far += self.inc_cls_num

        self._network.update_classifer(self._classes_seen_so_far, train_loader)

        # Gradient training happens at most on the first task; all later
        # tasks rely solely on the closed-form update in after_task.
        if task_idx == 0 and self.first_session_training:
            self._skip_train = False
        else:
            self._skip_train = True
            print(f"Not training on task {task_idx}")

    def observe(self, data):
        """One training step on a batch; returns (preds, acc, loss)."""
        if self._skip_train:
            # Dummy zero loss (requires_grad so an upstream .backward() still works).
            return None, 0., torch.tensor(0., device = self.device, requires_grad = True)

        # Labels are shifted to be task-local; since training only occurs on
        # task 0 (see before_task), _known_classes is 0 and the shift is a no-op.
        inputs, targets = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes

        logits = self._network(inputs)
        loss = F.cross_entropy(logits, targets)

        _, preds = torch.max(logits, dim=1)
        correct = preds.eq(targets.expand_as(preds)).sum().item()
        total = len(targets)

        acc = round(correct / total, 4)

        return preds, acc, loss

    def inference(self, data):
        """Evaluate one batch; returns (logits, acc)."""
        inputs, targets = data['image'].to(self.device), data['label']
        logits = self._network(inputs, True)  # inference=True: all-seen-class prompts
        _, preds = torch.max(logits, dim=1)

        correct = preds.cpu().eq(targets.expand_as(preds)).sum().item()
        total = len(targets)

        acc = round(correct / total, 4)

        return logits, acc

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Run the closed-form RanPAC update once the task's training is done."""
        self._known_classes = self._classes_seen_so_far

        if task_idx == 0:
            # Fixed random projection W_rand (feature_dim -> M), plus the
            # running ridge-regression statistics: Q (M x C) and G (M x M).
            self.W_rand = torch.randn(self._network.classifier.in_features, self.M)
            self.Q = torch.zeros(self.M, self.init_cls_num)
            self.G = torch.zeros(self.M, self.M)

        else:
            # New task: widen Q with zero columns for the new classes.
            self.Q = torch.cat((self.Q, torch.zeros(self.M, self.inc_cls_num)), dim=1)

        self.update_rp_classifier(train_loader, test_loaders[0].dataset.trfms)

    @torch.no_grad()
    def update_rp_classifier(self, train_loader, test_trfms):
        """Recompute the head weights by ridge regression on randomly
        projected features, accumulating statistics across tasks.

        Side effects: switches the classifier into RP mode, mutates
        ``train_loader.dataset.trfms``, updates ``self.Q``/``self.G``, and
        overwrites ``classifier.weight.data`` in place.
        """
        self._network.eval()
        # Use test-time transforms so extracted features are deterministic.
        train_loader.dataset.trfms = test_trfms

        self._network.classifier.use_RP = True
        self._network.classifier.W_rand = self.W_rand.to(self.device)

        # Extract features and labels for the whole task (kept on CPU).
        feature_list, label_list = [], []
        for batch in train_loader:
            x, y = batch['image'].to(self.device), batch['label']
            feature_list.append(self._network.get_feature(x).cpu())
            label_list.append(y)
        feature_list, label_list = torch.cat(feature_list, dim = 0), torch.cat(label_list, dim = 0)

        # NOTE(review): labels are assumed to be global class indices in
        # [0, _classes_seen_so_far) — confirm against the data loader.
        label_list = F.one_hot(label_list, self._classes_seen_so_far).to(torch.float32)

        # Random projection followed by ReLU non-linearity.
        proj_feature_list = F.relu(feature_list @ self.W_rand)

        # Accumulate sufficient statistics: Q += H^T Y, G += H^T H.
        self.Q += proj_feature_list.T @ label_list
        self.G += proj_feature_list.T @ proj_feature_list

        # Ridge selection over 10^-8 .. 10^8 on a split of the current task:
        # fit on the first 80% of samples, score MSE on the remaining 20%.
        ridges = 10.0**np.arange(-8,9)
        num_val_samples = int(proj_feature_list.shape[0] * 0.8)
        losses = []
        Q_val = proj_feature_list[:num_val_samples, :].T @ label_list[:num_val_samples, :]
        G_val = proj_feature_list[:num_val_samples, :].T @ proj_feature_list[:num_val_samples, :]
        for ridge in ridges:
            Wo = torch.linalg.solve(G_val + ridge * torch.eye(self.M), Q_val).T
            Y_train_pred = proj_feature_list[num_val_samples:, :] @ Wo.T
            losses.append(F.mse_loss(Y_train_pred, label_list[num_val_samples:, :]))
        ridge = ridges[np.argmin(np.array(losses))]
        print(f"Optimal lambda: {ridge}")

        # Final closed-form solve on the accumulated statistics; the head
        # weight is overwritten with the (num_classes, M) solution rows.
        Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.M), self.Q).T
        self._network.classifier.weight.data = Wo[:self._network.classifier.weight.shape[0], :].to(self.device)

    def get_parameters(self, config):
        """Return the network's parameters for the optimizer (config unused here)."""
        return self._network.parameters()