| | """ |
| | @inproceedings{10.24963/ijcai.2024/456, |
| | author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi}, |
| | title = {Dynamically anchored prompting for task-imbalanced continual learning}, |
| | booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, |
| | year = {2024}, |
| | } |
| | https://dl.acm.org/doi/10.24963/ijcai.2024/456 |
| | Adapted from https://github.com/chenxing6666/dap |
| | """ |
| |
|
| | import math |
| | import copy |
| | import torch |
| | import torch.nn.functional as F |
| | from .finetune import Finetune |
| | import numpy as np |
| | from torch.utils.data import DataLoader |
| |
|
# Module-level scalars created at import time. Nothing in the visible part of
# this module reads or updates them.
# NOTE(review): presumably carried over from the reference implementation --
# confirm that no external importer relies on them before removing.
global_max_dist, global_max_dist2 = torch.tensor(0), torch.tensor(0)
global_lam = 0.25
| |
|
| |
|
class DAP(Finetune):
    """Dynamically Anchored Prompting learner for task-imbalanced continual learning.

    Two copies of the backbone are kept:

    * ``self.network`` -- the backbone passed in, trained through ``observe``.
    * ``self.original_model`` -- a deep copy put in eval mode; its
      ``pre_logits`` output is handed to ``self.network`` as ``cls_features``.

    ``self.prompt_center`` starts as ``None`` and is never assigned inside this
    class. NOTE(review): presumably the trainer refreshes it between tasks with
    ``cal_center`` -- verify against the caller.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Set up the learner, the per-task class masks and the reference model.

        Args:
            backbone: Network used as the trainable ``self.network``.
            feat_dim: Forwarded unchanged to ``Finetune.__init__``.
            num_class: Forwarded unchanged to ``Finetune.__init__``.
            **kwargs: Must contain ``train_mask``, ``task_inc``,
                ``pull_constraint``, ``pull_constraint_coeff``, ``task_num``
                and ``freeze``; everything is also forwarded to the parent.

        Raises:
            ValueError: If ``self.num_class`` is not divisible by
                ``kwargs['task_num']``.
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        # ``self.backbone``, ``self.num_class``, ``self.device`` and
        # ``self.loss_fn`` are read below without being assigned here, so they
        # are expected to be provided by ``Finetune.__init__``.
        self.kwargs = kwargs
        self.network = backbone
        self.train_mask = kwargs['train_mask']
        self.task_inc = kwargs['task_inc']
        self.pull_constraint = kwargs['pull_constraint']
        self.pull_constraint_coeff = kwargs['pull_constraint_coeff']

        self.task_idx = 0
        self.task_data_count = []  # training-set size of every task seen so far
        self.prompt_center = None

        # Split the label space into equally sized, contiguous per-task blocks.
        task_num = kwargs['task_num']
        if self.num_class % task_num != 0:
            raise ValueError('Number of classes must be divisible by number of tasks')
        classes_per_task = self.num_class // task_num
        self.class_mask = [
            list(range(start, start + classes_per_task))
            for start in range(0, self.num_class, classes_per_task)
        ]

        self.original_model = copy.deepcopy(self.backbone)
        self.original_model.to(self.device)
        self.original_model.eval()

        freeze_prefixes = kwargs['freeze']
        if freeze_prefixes:
            # The reference copy is frozen entirely.
            for param in self.original_model.parameters():
                param.requires_grad = False

            # In the trainable network only parameters whose name starts with
            # one of the configured prefixes are frozen.
            prefixes = tuple(freeze_prefixes)
            for name, param in self.network.named_parameters():
                if name.startswith(prefixes):
                    param.requires_grad = False

        self.loss_fn.to(self.device)

    def _reference_features(self, x):
        """Return the reference model's ``pre_logits`` for ``x`` without gradients.

        Returns ``None`` when ``self.original_model`` is ``None``.
        """
        if self.original_model is None:
            return None
        with torch.no_grad():
            return self.original_model(x)['pre_logits']

    def observe(self, data, train_gprompt=False, gen=False):
        """Run one training forward pass and compute the loss for a batch.

        Args:
            data: Mapping with the keys ``'image'`` and ``'label'``.
            train_gprompt: When true, the two prompt-similarity terms are added
                to the classification loss. ``before_task`` must have been
                called at least once, because the weighting reads
                ``self.task_data_count``.
            gen: Forwarded to the network as ``gen=`` only when truthy.

        Returns:
            Tuple ``(pred, accuracy, loss)``: arg-max predictions, the fraction
            of correct predictions in the batch, and the loss tensor.

        Raises:
            RuntimeError: If the loss is not finite.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        cls_features = self._reference_features(x)

        forward_kwargs = {
            'task_id': self.task_idx,
            'cls_features': cls_features,
            'train': True,
        }
        if gen:
            # Passed only when set, so the network's own default applies otherwise.
            forward_kwargs['gen'] = gen
        output = self.network(x, **forward_kwargs)
        logits = output['logits']

        # Restrict the prediction to the classes of the current task.
        if self.train_mask and self.class_mask is not None:
            current_classes = self.class_mask[self.task_idx]
            other_classes = np.setdiff1d(np.arange(self.num_class), current_classes)
            other_classes = torch.tensor(
                other_classes, dtype=torch.int64, device=self.device)
            logits = logits.index_fill(
                dim=1, index=other_classes, value=float('-inf'))

        loss = self.loss_fn(logits, y)
        if train_gprompt:
            plasticity_loss = self.cal_latestsimilarity_loss(
                model=self.network, task_id=self.task_idx)['similarity']
            stability_loss = self.cal_similarity_loss(
                model=self.network, task_id=self.task_idx,
                prompt_center=self.prompt_center)['avg_similarity']

            # alpha is the latest task's size min-max normalised over all tasks
            # seen so far: 0 for the smallest task, ~1 for the largest one.
            smallest = min(self.task_data_count)
            largest = max(self.task_data_count)
            latest = self.task_data_count[-1]
            epsilon = 1e-10  # avoids 0/0 when all tasks have the same size
            alpha = (latest - smallest) / (largest - smallest + epsilon)

            loss = loss + alpha * stability_loss + (1 - alpha) * plasticity_loss
        elif self.pull_constraint and 'reduce_sim' in output:
            # NOTE(review): the source's indentation was lost; the pull
            # constraint is applied only on the non-gprompt path here --
            # confirm against the reference implementation.
            loss = loss - self.pull_constraint_coeff * output['reduce_sim']

        if not math.isfinite(loss.item()):
            raise RuntimeError(f'Loss is {loss.item()}, stopping training')

        pred = torch.argmax(logits, dim=1)
        correct = torch.sum(pred == y).item()

        return pred, correct / x.size(0), loss

    def inference(self, data):
        """Predict labels for a batch without tracking gradients.

        Args:
            data: Mapping with the keys ``'image'`` and ``'label'``.

        Returns:
            Tuple ``(pred, accuracy)``: arg-max predictions and the fraction of
            correct predictions in the batch.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        with torch.no_grad():
            cls_features = self._reference_features(x)
            output = self.network(
                x, task_id=self.task_idx, cls_features=cls_features, gen=True)
            logits = output['logits']

            if self.task_inc and self.class_mask is not None:
                # NOTE(review): the mask is taken from ``self.task_idx`` (the
                # task set by ``before_task``), not from the batch -- confirm
                # this is intended when evaluating loaders of earlier tasks.
                allowed = torch.tensor(
                    self.class_mask[self.task_idx],
                    dtype=torch.int64, device=self.device)
                # Additive mask: 0 for allowed classes, -inf everywhere else.
                logits_mask = torch.full_like(
                    logits, float('-inf'), device=self.device)
                logits_mask = logits_mask.index_fill(1, allowed, 0.0)
                logits = logits + logits_mask

            pred = torch.argmax(logits, dim=1)
            correct = torch.sum(pred == y).item()

        return pred, correct / x.size(0)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record the new task index and the size of its training set.

        ``buffer`` and ``test_loaders`` are accepted but not used here.
        """
        self.task_idx = task_idx
        self.network.task_id = task_idx
        self.task_data_count.append(len(train_loader.dataset))

    @staticmethod
    def cal_latestsimilarity_loss(model: torch.nn.Module, task_id=-1):
        """Cosine distance between the general prompt and one task prompt.

        Args:
            model: Network exposing ``prompt.generalprompt`` and
                ``prompt.taskprompt``.
            task_id: Index into ``prompt.taskprompt``.

        Returns:
            Dict with the key ``'similarity'`` holding
            ``1 - cosine_similarity`` of the two flattened prompts. The task
            prompt is detached, so gradients reach only the general prompt.
        """
        general_flat = model.prompt.generalprompt.reshape(-1)
        task_flat = model.prompt.taskprompt[task_id].detach().reshape(-1)
        distance = 1 - F.cosine_similarity(general_flat, task_flat, dim=0)
        return {'similarity': distance}

    @staticmethod
    def cal_center(model: torch.nn.Module, task_id=-1, task_data_count=None, prompt_center=None):
        """Blend the prompt of task ``task_id - 1`` into the running prompt centre.

        Args:
            model: Network exposing ``prompt.taskprompt``.
            task_id: Index of the task about to start. For ``task_id <= 0`` a
                zero vector shaped like a flattened task prompt is returned.
            task_data_count: Optional per-task sample counts. When given, each
                of the first ``task_id`` tasks is weighted by the inverse of
                its count. When omitted or empty, the newest prompt gets
                weight ``1 / task_id`` and the old centre ``1 - 1 / task_id``.
            prompt_center: Previous centre; when ``None`` the flattened prompt
                of task 0 is used as the starting point.

        Returns:
            The updated, flattened prompt centre.

        Raises:
            ZeroDivisionError: If one of the used counts is 0.
        """
        task_prompts = model.prompt.taskprompt
        if task_id <= 0:
            return torch.zeros_like(task_prompts[0].detach().reshape(-1))

        if prompt_center is None:
            prompt_center = task_prompts[0].detach().reshape(-1)
        newest_prompt = task_prompts[task_id - 1].detach().reshape(-1)

        if task_data_count:
            inverse_counts = [1 / count for count in task_data_count[:task_id]]
            total = sum(inverse_counts)
            new_weight = inverse_counts[-1] / total
            old_weight = sum(inverse_counts[:-1]) / total
        else:
            # The old-centre weight used to be left undefined on this path,
            # which made the blend below fail with an unbound local.
            new_weight = 1.0 / task_id
            old_weight = 1.0 - new_weight

        return (prompt_center * old_weight) + (newest_prompt * new_weight)

    @staticmethod
    def cal_similarity_loss(model: torch.nn.Module, task_id=-1, prompt_center=None):
        """Cosine distance between the general prompt and the prompt centre.

        Args:
            model: Network exposing ``prompt.generalprompt``.
            task_id: Current task index; the term is zero for ``task_id <= 0``.
            prompt_center: Flattened centre to compare against. ``None`` is
                treated as "no centre available" and also yields a zero term.

        Returns:
            Dict whose keys ``'similarity'`` and ``'avg_similarity'`` hold the
            same scalar tensor: ``1 - cosine_similarity`` when a centre is
            available, otherwise a zero with the general prompt's dtype and
            device.
        """
        general_prompt = model.prompt.generalprompt

        if task_id > 0 and prompt_center is not None:
            general_flat = general_prompt.reshape(-1)
            # Moving the centre is a no-op when it already lives on that device.
            center = prompt_center.to(general_flat.device)
            distance = 1 - F.cosine_similarity(general_flat, center, dim=0)
        else:
            distance = torch.zeros(
                (), dtype=general_prompt.dtype, device=general_prompt.device)

        return {'similarity': distance, 'avg_similarity': distance}