# -*- coding: utf-8 -*-
"""
Buffer layers for analytic continual learning (ACL) [1-3].

This implementation is based on the official implementation:
https://github.com/ZHUANGHP/Analytic-continual-learning

References:
[1] Zhuang, Huiping, et al.
    "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
[2] Zhuang, Huiping, et al.
    "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
[3] Zhuang, Huiping, et al.
    "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
"""
import torch
from typing import Optional, Union, Callable
from abc import ABCMeta, abstractmethod

__all__ = [
    "Buffer",
    "RandomBuffer",
    "activation_t",
]

# An activation is any callable (or module) that maps a tensor to a tensor.
activation_t = Union[Callable[[torch.Tensor], torch.Tensor], torch.nn.Module]

class Buffer(torch.nn.Module, metaclass=ABCMeta):
    """Abstract base class for ACL buffer layers."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

class RandomBuffer(torch.nn.Linear, Buffer):
    """
    Random buffer layer for ACIL [1] and DS-AL [2]: a frozen, randomly
    initialized linear projection (optionally followed by an activation)
    whose weights are stored as buffers rather than trainable parameters.

    This implementation is based on the official implementation:
    https://github.com/ZHUANGHP/Analytic-continual-learning

    References:
    [1] Zhuang, Huiping, et al.
        "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
        Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
        "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
        Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device=None,
        dtype=torch.float,
        activation: Optional[activation_t] = torch.relu_,
    ) -> None:
        # Deliberately skip torch.nn.Linear.__init__: calling the next
        # __init__ in the MRO (via Buffer, ultimately torch.nn.Module)
        # avoids registering the weight and bias as trainable parameters.
        super(torch.nn.Linear, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.activation: activation_t = (
            torch.nn.Identity() if activation is None else activation
        )

        W = torch.empty((out_features, in_features), **factory_kwargs)
        b = torch.empty(out_features, **factory_kwargs) if bias else None

        # Register the weight and bias as buffers instead of parameters,
        # so the random projection stays frozen and receives no gradients.
        self.register_buffer("weight", W)
        self.register_buffer("bias", b)

        # Random initialization, reusing torch.nn.Linear.reset_parameters
        # (it operates on self.weight / self.bias in place).
        self.reset_parameters()

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Match the input to the weight's device and dtype, then apply the
        # frozen linear projection followed by the activation.
        X = X.to(self.weight)
        return self.activation(super().forward(X))
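

# A minimal usage sketch (illustrative, not from the original file); the
# sizes below are hypothetical. The buffer expands backbone features into a
# wider random space, and exposes no trainable parameters.
if __name__ == "__main__":
    buffer = RandomBuffer(in_features=512, out_features=8192)
    X = torch.randn(4, 512)  # a batch of 4 backbone feature vectors
    H = buffer(X)  # frozen random projection + in-place ReLU
    print(H.shape)  # torch.Size([4, 8192])
    print(sum(p.numel() for p in buffer.parameters()))  # 0 (buffers only)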