boringKey's picture
Upload 236 files
5fee096 verified
'''
@misc{mcdonnell2024ranpacrandomprojectionspretrained,
title={RanPAC: Random Projections and Pre-trained Models for Continual Learning},
author={Mark D. McDonnell and Dong Gong and Amin Parveneh and Ehsan Abbasnejad and Anton van den Hengel},
year={2024},
eprint={2307.02251},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2307.02251},
}
Code Reference:
https://github.com/RanPAC/RanPAC
'''
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer
from .backbone.clip import CLIP, tokenize
from .backbone.vit import ViTZoo, ViT_in21k_adapter
# Backbone alias used by the isinstance() dispatch in this module: ViT backbones are
# the in21k adapter variant. CLIP is used directly under its imported name (the old
# `CLIP = CLIP` line was a no-op self-assignment and has been removed).
VIT = ViT_in21k_adapter
class CosineLinear(nn.Module):
    """Linear head that runs either as a cosine classifier or as a random-projection readout.

    * ``use_RP is False``: logits are the cosine similarity between the L2-normalised
      input rows and the L2-normalised rows of ``weight``, scaled by ``sigma``.
    * ``use_RP is True``: the input is projected with ``W_rand``, passed through ReLU,
      multiplied by ``weight`` (no normalisation) and scaled by ``sigma``. In this mode
      the caller is responsible for setting ``W_rand`` and for replacing ``weight`` with
      a matrix whose second dimension equals ``W_rand.shape[1]``.

    Args:
        in_features: size of each input sample.
        out_features: number of output logits.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised, then filled by reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.sigma = nn.Parameter(torch.empty(1))
        self.reset_parameters()
        self.use_RP = False
        # Random projection matrix (in_features x M); must be set before use_RP is enabled.
        self.W_rand = None

    def reset_parameters(self):
        """Initialise ``weight`` uniformly in [-1/sqrt(in_features), 1/sqrt(in_features)] and ``sigma`` to 1."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.sigma.data.fill_(1)

    def forward(self, input):
        """Return ``sigma``-scaled logits for ``input`` of shape (batch, in_features).

        Raises:
            RuntimeError: if ``use_RP`` is set but ``W_rand`` has not been provided.
        """
        if not self.use_RP:
            out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        else:
            if self.W_rand is None:
                # An explicit raise (unlike `assert`) is not stripped under `python -O`.
                raise RuntimeError("CosineLinear.use_RP is True but W_rand has not been set")
            out = F.linear(F.relu(input @ self.W_rand), self.weight)
        return self.sigma * out
class Network(nn.Module):
    """Backbone + classifier wrapper used by RanPAC.

    Two backbone families are supported:

    * ``VIT``: features are the backbone output.
    * ``CLIP``: features are the image embedding concatenated with the text embedding
      of the class prompt whose image-text logit is highest for that image.

    Args:
        backbone: a ``VIT`` or ``CLIP`` instance.
        device: device the classifier (and CLIP text tokens) are moved to.
        **kwargs: must contain ``prompt_template`` (a ``str.format`` template with one
            positional placeholder for the class name) when ``backbone`` is ``CLIP``.

    Raises:
        TypeError: if ``backbone`` is neither ``VIT`` nor ``CLIP``.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()
        self._cur_task_id = -1
        self.backbone = backbone
        self.device = device
        self.classifier = None
        # Prompt bookkeeping; only populated when the backbone is CLIP.
        self.accm_class_names = []
        self.curr_class_names = []
        self.accm_text_tokens = None
        self.curr_text_tokens = None
        self.prompt_template = kwargs.get('prompt_template')

        if isinstance(self.backbone, VIT):
            self.feature_dim = self.backbone.feat_dim
        elif isinstance(self.backbone, CLIP):
            # Assumes the feature vector is [image embedding | text embedding]; the text part
            # is sized by the text transformer width -- TODO confirm it matches the text output dim.
            self.feature_dim = self.backbone.visual.output_dim + self.backbone.transformer.width
            self.prompt_template = kwargs['prompt_template']
        else:
            raise TypeError(f"Unsupported backbone type: {type(self.backbone).__name__}")

    def update_classifer(self, num_classes, train_loader):
        """Start a new task: refresh CLIP prompts (if any) and build a fresh classifier.

        (The misspelled name is kept for existing callers.)

        Args:
            num_classes: total number of classes the new classifier must output.
            train_loader: loader whose ``dataset.get_class_names()`` gives the current
                task's class names (only consulted for CLIP backbones).
        """
        if isinstance(self.backbone, CLIP):
            self.curr_class_names = list(train_loader.dataset.get_class_names())
            self.accm_class_names += self.curr_class_names
            self.curr_text_tokens = self._tokenize(self.curr_class_names)
            self.accm_text_tokens = self._tokenize(self.accm_class_names)
        elif not isinstance(self.backbone, VIT):
            raise TypeError(f"Unsupported backbone type: {type(self.backbone).__name__}")
        self._cur_task_id += 1
        # Assigning a new module replaces the previous classifier (if any).
        self.classifier = CosineLinear(self.feature_dim, num_classes).to(self.device)

    def _tokenize(self, class_names):
        """Tokenise ``prompt_template`` filled with each class name and move the tokens to ``device``."""
        return tokenize([self.prompt_template.format(c) for c in class_names]).to(self.device)

    def _extract_features(self, x, text_tokens):
        """Return backbone features for ``x``.

        For CLIP, ``text_tokens`` are the candidate prompts; each image's embedding is
        concatenated with the text embedding of its highest-scoring prompt. ``text_tokens``
        is ignored for VIT backbones.
        """
        if isinstance(self.backbone, VIT):
            return self.backbone(x)
        if isinstance(self.backbone, CLIP):
            features_image, features_text, logits_per_image, _ = self.backbone(x, text_tokens)
            # softmax is monotonic, so the argmax of the raw logits picks the same prompt.
            best_prompt = logits_per_image.argmax(dim=1)
            return torch.cat([features_image, features_text[best_prompt]], dim=1)
        raise TypeError(f"Unsupported backbone type: {type(self.backbone).__name__}")

    def get_feature(self, x):
        """Features for ``x``; CLIP prompts are restricted to the current task's classes."""
        return self._extract_features(x, self.curr_text_tokens)

    def forward(self, x, inference=False):
        """Classifier logits for ``x``.

        Args:
            x: input image batch.
            inference: if True, CLIP prompts cover every class seen so far; otherwise
                only the current task's classes.
        """
        text_tokens = self.accm_text_tokens if inference else self.curr_text_tokens
        return self.classifier(self._extract_features(x, text_tokens))
class RanPAC(nn.Module):
    """RanPAC continual learner: random projections on top of a pre-trained backbone.

    Only the first session may be trained by gradient descent (if
    ``first_session_training`` is set). After every task, a closed-form ridge-regression
    readout is refit on ReLU-random-projected features ``H``. ``Q`` accumulates
    ``H^T Y`` and ``G`` accumulates ``H^T H`` over all tasks, where ``Y`` are one-hot
    labels.

    Keyword args read here: ``first_session_training``, ``init_cls_num``,
    ``inc_cls_num``, ``total_cls_num``, ``task_num`` and ``M`` (random-projection
    width). All kwargs are also forwarded to ``Network``.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()
        self._network = Network(backbone, device, **kwargs)
        self.device = device
        self.first_session_training = kwargs["first_session_training"]
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.total_cls_num = kwargs['total_cls_num']
        self.task_num = kwargs["task_num"]
        self.M = kwargs['M']  # width of the random projection
        self._known_classes = 0
        self._classes_seen_so_far = 0
        self._skip_train = False  # when True, observe() performs no training
        self._network.to(self.device)
        if isinstance(backbone, CLIP):
            # Freeze every parameter whose name does not contain 'adapt'.
            for name, param in self._network.named_parameters():
                if 'adapt' not in name:
                    param.requires_grad = False

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the class count, rebuild the classifier, and decide whether to train this task."""
        if task_idx == 0:
            self._classes_seen_so_far = self.init_cls_num
        elif task_idx > 0:
            self._classes_seen_so_far += self.inc_cls_num
        self._network.update_classifer(self._classes_seen_so_far, train_loader)
        if task_idx == 0 and self.first_session_training:
            self._skip_train = False
        else:
            self._skip_train = True
            print(f"Not training on task {task_idx}")

    def observe(self, data):
        """One training step on ``data``; returns ``(preds, acc, loss)``.

        When training is skipped, returns ``(None, 0., zero_loss)`` where ``zero_loss``
        has ``requires_grad=True`` so the caller can still call ``backward()`` on it.
        """
        if self._skip_train:
            return None, 0., torch.tensor(0., device=self.device, requires_grad=True)
        inputs = data['image'].to(self.device)
        targets = data['label'].to(self.device) - self._known_classes
        logits = self._network(inputs)
        loss = F.cross_entropy(logits, targets)
        _, preds = torch.max(logits, dim=1)
        correct = preds.eq(targets.expand_as(preds)).sum().item()
        acc = round(correct / len(targets), 4)
        return preds, acc, loss

    def inference(self, data):
        """Evaluate on ``data`` over all classes seen so far; returns ``(logits, acc)``."""
        inputs, targets = data['image'].to(self.device), data['label']
        logits = self._network(inputs, True)
        _, preds = torch.max(logits, dim=1)
        correct = preds.cpu().eq(targets.cpu().expand_as(preds)).sum().item()
        acc = round(correct / len(targets), 4)
        return logits, acc

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the ridge statistics for the new classes and refit the RP readout."""
        self._known_classes = self._classes_seen_so_far
        if task_idx == 0:
            # The random projection is drawn once and reused for every later task.
            self.W_rand = torch.randn(self._network.classifier.in_features, self.M)
            self.Q = torch.zeros(self.M, self.init_cls_num)
            self.G = torch.zeros(self.M, self.M)
        else:
            self.Q = torch.cat((self.Q, torch.zeros(self.M, self.inc_cls_num)), dim=1)
        self.update_rp_classifier(train_loader, test_loaders[0].dataset.trfms)

    def _collect_features(self, train_loader, test_trfms):
        """Return (features, labels) for every batch of ``train_loader``, extracted with ``test_trfms``.

        The dataset's transforms are swapped for ``test_trfms`` only for the duration
        of the pass and restored afterwards, so the training loader is not left altered.
        Features are cast to float32 so they can be multiplied with the fp32 ``W_rand``.
        """
        dataset = train_loader.dataset
        original_trfms = dataset.trfms
        dataset.trfms = test_trfms
        feature_list, label_list = [], []
        try:
            for batch in train_loader:
                x = batch['image'].to(self.device)
                feature_list.append(self._network.get_feature(x).float().cpu())
                label_list.append(batch['label'])
        finally:
            dataset.trfms = original_trfms
        return torch.cat(feature_list, dim=0), torch.cat(label_list, dim=0)

    @staticmethod
    def _optimise_ridge(features, targets):
        """Pick the ridge strength in {1e-8, ..., 1e8} with the lowest held-out MSE.

        The first 80% of rows fit a ridge-regression readout for each candidate; the
        remaining rows score it with MSE. Ties go to the smallest candidate.

        Args:
            features: (N, M) float tensor of projected features.
            targets: (N, C) float tensor of regression targets.

        Returns:
            The selected ridge as a Python float.
        """
        ridges = (10.0 ** np.arange(-8, 9)).tolist()
        num_fit = int(features.shape[0] * 0.8)
        fit_x, fit_y = features[:num_fit], targets[:num_fit]
        held_x, held_y = features[num_fit:], targets[num_fit:]
        Q_fit = fit_x.T @ fit_y
        G_fit = fit_x.T @ fit_x
        eye = torch.eye(G_fit.shape[0], dtype=G_fit.dtype)
        losses = []
        for ridge in ridges:
            # solve() is numerically more stable than forming an explicit inverse.
            Wo = torch.linalg.solve(G_fit + ridge * eye, Q_fit).T
            losses.append(F.mse_loss(held_x @ Wo.T, held_y))
        return ridges[int(torch.stack(losses).argmin())]

    @torch.no_grad()
    def update_rp_classifier(self, train_loader, test_trfms):
        """Accumulate RP statistics for this task's data and install the ridge-regression readout."""
        self._network.eval()
        classifier = self._network.classifier
        classifier.use_RP = True
        classifier.W_rand = self.W_rand.to(self.device)  # feature_dim x M

        features, labels = self._collect_features(train_loader, test_trfms)
        targets = F.one_hot(labels.long(), self._classes_seen_so_far).to(torch.float32)
        proj_features = F.relu(features @ self.W_rand)

        self.Q += proj_features.T @ targets
        self.G += proj_features.T @ proj_features

        # The ridge is tuned on the current task's data only, then applied to the global statistics.
        ridge = self._optimise_ridge(proj_features, targets)
        print(f"Optimal lambda: {ridge}")
        Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.M), self.Q).T
        classifier.weight.data = Wo[:classifier.weight.shape[0], :].to(self.device)  # num_classes x M

    def get_parameters(self, config):
        """Parameters to hand to the optimiser (all network parameters)."""
        return self._network.parameters()