| | """ |
| | @misc{wu2025sdlorascalabledecoupledlowrank, |
| | title={SD-LoRA: Scalable Decoupled Low-Rank Adaptation for Class Incremental Learning}, |
| | author={Yichen Wu and Hongming Piao and Long-Kai Huang and Renzhen Wang and Wanhua Li and Hanspeter Pfister and Deyu Meng and Kede Ma and Ying Wei}, |
| | year={2025}, |
| | eprint={2501.13198}, |
| | archivePrefix={arXiv}, |
| | primaryClass={cs.LG}, |
| | url={https://arxiv.org/abs/2501.13198}, |
| | } |
| | |
| | Adapted from https://github.com/WuYichen-97/SD-Lora-CL |
| | |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import copy |
| | import numpy as np |
| |
|
| | from torch.nn import functional as F |
| | from .backbone.transformer import MultiHeadAttention_SDLoRA |
| |
|
class Model(nn.Module):
    """Backbone plus an incrementally growing linear classification head.

    Each call to :meth:`update_fc` starts a new task: the head is replaced by a
    larger one and the weights of previously learned classes are copied over.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self._cur_task_id = -1          # incremented by update_fc() before each task
        self.backbone = backbone
        self.device = device
        self.embed_dim = kwargs["embd_dim"]    # key spelled 'embd_dim' by the caller's config
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']

    def update_fc(self):
        """Grow the classifier for the next task, preserving old-class weights."""
        self._cur_task_id += 1
        num_classes = self.init_cls_num + self.inc_cls_num * self._cur_task_id

        classifier = nn.Linear(self.embed_dim, num_classes, bias=True)
        nn.init.kaiming_uniform_(classifier.weight, nonlinearity='linear')
        nn.init.constant_(classifier.bias, 0)

        if self._cur_task_id > 0:
            # Copy the rows belonging to already-seen classes from the old head.
            nb_output = self.classifier.out_features
            classifier.weight.data[:nb_output] = copy.deepcopy(self.classifier.weight.data)
            classifier.bias.data[:nb_output] = copy.deepcopy(self.classifier.bias.data)
            del self.classifier

        self.classifier = classifier

    def forward(self, x, inference=False):
        """Return classification logits for input batch `x`.

        `inference` is accepted for interface compatibility; it does not
        change the computation here.
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits
| |
|
class SD_LoRA(nn.Module):
    """SD-LoRA learner: Scalable Decoupled Low-Rank Adaptation for CIL.

    Wraps a `Model` (backbone + growing classifier) and drives, per task:
    optional LoRA rank reduction, per-task magnitude parameters, selective
    unfreezing, and (optionally) post-task merging of a redundant new LoRA
    direction into previously learned ones ("knowledge distillation").
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.task_num = kwargs["task_num"]
        self.init_mag = kwargs['init_mag']
        # rank_reduction: [enabled, task_a, task_b, rank_after_a, rank_after_b]
        # — assumed layout from the usage in before_task; TODO confirm vs config.
        self.rank_reduction = kwargs['rank_reduction']
        # knowledge_dist: [enabled, residual_threshold] — same caveat.
        self.knowledge_dist = kwargs['knowledge_dist']
        self._known_classes = 0

        self._network = Model(backbone, device, **kwargs)
        self.attention_modules = [m for m in self._network.modules()
                                  if isinstance(m, MultiHeadAttention_SDLoRA)]

    def observe(self, data):
        """One training step; returns (predictions, batch accuracy, loss).

        The cross-entropy loss is computed only over the logits of the
        classes introduced by the current task.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        logits = self._network(x)

        # Shift labels into the current task's local label space.
        fake_y = y - self._known_classes
        loss = F.cross_entropy(logits[:, self._known_classes:], fake_y)

        preds = logits.max(1)[1]
        acc = preds.eq(y).sum().item() / y.size(0)
        return preds, acc, loss

    def inference(self, data):
        """Evaluation step over all seen classes; returns (predictions, accuracy)."""
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logits = self._network(x, inference=True)
        preds = logits.max(1)[1]
        acc = preds.eq(y).sum().item() / y.size(0)
        return preds, acc

    @torch.no_grad()
    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for task `task_idx`: grow the head, schedule ranks, freeze params."""
        self._network.update_fc()

        # Optionally reduce the LoRA rank at the two scheduled task indices.
        if self.rank_reduction[0]:
            if task_idx == self.rank_reduction[1]:
                for module in self.attention_modules:
                    module.lora_rank = self.rank_reduction[3]
            elif task_idx == self.rank_reduction[2]:
                for module in self.attention_modules:
                    module.lora_rank = self.rank_reduction[4]

        # One learnable magnitude per task seen so far, shared by all attention modules.
        mag = nn.ParameterList([nn.Parameter(torch.Tensor([self.init_mag]))
                                for _ in range(task_idx + 1)])
        for module in self.attention_modules:
            module.mag_lora = mag
            module.init_param()

        self._network = self._network.to(self.device)

        # Unfreeze only: the classifier, the current task's LoRA factors, and
        # the (non-assimilated) magnitude parameters.
        unfrozen_params = []
        for name, param in self._network.named_parameters():
            param.requires_grad_(False)
            # BUGFIX: the original test was `'lora' and f'list.{task_idx}' in name`,
            # which ('lora' being truthy) reduced to `f'list.{task_idx}' in name`
            # alone; both substrings are now required, as intended.
            if ('classifier' in name
                    or ('lora' in name and f'list.{task_idx}' in name)
                    or ('mag' in name and 'assimilated' not in name)):
                param.requires_grad_(True)
                unfrozen_params.append(name)

        print(f"Current task : {task_idx}, Parameters to be updated: {len(unfrozen_params)}")

    @torch.no_grad()
    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Advance the class counter and, if enabled, merge redundant LoRA directions."""
        self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num

        if self.knowledge_dist[0] and task_idx > 0:
            for layer, module in enumerate(self.attention_modules):
                dirs_q = self._normalized_dirs(module.lora_A_q_list, module.lora_B_q_list)
                # BUGFIX: the original's zero-norm fallback for the V branch used
                # the *Q* matrices (copy-paste error); the V matrices are used here.
                dirs_v = self._normalized_dirs(module.lora_A_v_list, module.lora_B_v_list)

                self._try_merge(dirs_q, module.assimilated_mag_lora_q,
                                module.lora_A_q_list, module.lora_B_q_list,
                                task_idx, layer, 'Q')
                self._try_merge(dirs_v, module.assimilated_mag_lora_v,
                                module.lora_A_v_list, module.lora_B_v_list,
                                task_idx, layer, 'V')

    @staticmethod
    def _normalized_dirs(A_list, B_list):
        """Per-task LoRA update directions B @ A, scaled by 1/(||B||*||A||) when nonzero."""
        dirs = []
        for A, B in zip(A_list, B_list):
            delta = B.weight @ A.weight
            scale = torch.norm(B.weight) * torch.norm(A.weight)
            dirs.append(delta / scale if scale != 0 else delta)
        return dirs

    def _try_merge(self, dirs, assimilated_mag, A_list, B_list, task_idx, layer, tag):
        """Fit the newest direction onto previous ones; merge and zero it if the fit is good."""
        flat = [d.flatten() for d in dirs]
        last_dir = flat[-1].unsqueeze(1)
        prev_dirs = torch.stack(flat[:-1], dim=-1)

        fit = torch.linalg.lstsq(prev_dirs, last_dir)

        # NOTE(review): torch.linalg.lstsq returns an *empty* residuals tensor
        # under the default CPU driver ('gelsy'); this comparison assumes a
        # driver that populates residuals (e.g. 'gelsd') — confirm against the
        # training environment.
        if fit.residuals < self.knowledge_dist[1]:
            print(f'Layer {layer}: {fit.residuals.item()} < {self.knowledge_dist[1]}, {tag} Merged with {fit.solution}')

            assert prev_dirs.shape[1] == len(assimilated_mag) - 1
            for ii in range(prev_dirs.shape[1]):
                # BUGFIX: the original indexed `solution[i]` with a stale loop
                # variable; each previous task gets its own coefficient.
                assimilated_mag[ii] += fit.solution[ii]

            # BUGFIX: nn.init.zeros_ requires a Tensor; the original passed the
            # Linear modules themselves, which would raise on the merge path.
            nn.init.zeros_(B_list[task_idx].weight)
            nn.init.zeros_(A_list[task_idx].weight)

    def get_parameters(self, config):
        """Expose the wrapped network's parameters to the optimizer."""
        return self._network.parameters()