|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class ClipAdapter(nn.Module):
    """Bottleneck MLP adapter: ``c_in -> bottleneck -> c_in``.

    Both stages are a bias-free ``nn.Linear`` followed by a (non-inplace)
    ``LeakyReLU``. ``forward`` returns the intermediate activation as well as
    the final output so callers can use either.

    Args:
        c_in: width of the incoming (and outgoing) features.
        bottleneck: width of the intermediate representation (default 768).
    """

    def __init__(self, c_in, bottleneck=768):
        super().__init__()
        # The attribute names and the Sequential layout define the state_dict
        # keys (``fc1.0.weight`` / ``fc2.0.weight``), so they must stay as-is
        # for saved checkpoints to keep loading.
        self.fc1 = self._projection(c_in, bottleneck)
        self.fc2 = self._projection(bottleneck, c_in)

    @staticmethod
    def _projection(in_features, out_features):
        """Build one adapter stage: bias-free Linear followed by LeakyReLU."""
        return nn.Sequential(
            nn.Linear(in_features, out_features, bias=False),
            nn.LeakyReLU(inplace=False),
        )

    def forward(self, x):
        """Return ``(hidden, output)``.

        ``hidden`` is ``fc1(x)`` (bottleneck width) and ``output`` is
        ``fc2(hidden)`` (back at ``c_in`` width).
        """
        hidden = self.fc1(x)
        return hidden, self.fc2(hidden)
|
|
|
|
|
|
|
|
class CLIPAD(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/Continual-Mega/Continual-Mega",
    paper_url="https://arxiv.org/abs/2506.00956",
):
    """CLIP visual tower with ``ClipAdapter`` modules inserted after selected blocks.

    The forward pass re-runs the visual tower step by step, so that after every
    transformer block whose 1-based index appears in ``features`` the block
    output can be routed through a dedicated adapter.

    Args:
        clip_model: CLIP model object; only its ``.visual`` attribute is used in
            ``forward``.
        features: sequence of 1-based transformer block indices to adapt. It must
            support ``in`` and ``.index``, e.g. a list or tuple of ints. Adapter
            ``k`` serves block ``features[k]``.
        adapter_dim: input/output width of every adapter. It must equal the token
            width of ``clip_model.visual``. Defaults to 1024, the previously
            hard-coded value.
        bottleneck: bottleneck width of every adapter. Defaults to 768.
    """

    def __init__(self, clip_model, features, *, adapter_dim=1024, bottleneck=768):
        super().__init__()
        # NOTE: the visual tower is registered twice (as ``clipmodel.visual`` and
        # as ``image_encoder``), so its weights appear under both prefixes in
        # state_dict. Both attributes are kept for checkpoint compatibility.
        self.clipmodel = clip_model
        self.image_encoder = clip_model.visual
        self.features = features
        self.adapters = nn.ModuleList(
            [ClipAdapter(adapter_dim, bottleneck=bottleneck) for _ in features]
        )

    def forward(self, x):
        """Encode ``x`` and collect adapter bottleneck activations.

        Steps:
          1. ``conv1``, then flatten/permute to tokens.
          2. Prepend the class token and add the positional embedding.
          3. Apply ``patch_dropout`` and ``ln_pre``.
          4. Run the transformer blocks. After each adapted block, the output is
             blended as ``0.9 * x + 0.1 * adapter_output``.
          5. ``_global_pool``, ``ln_post``, and the optional ``proj``.

        Args:
            x: input fed to ``image_encoder.conv1``. Presumably an image batch
                of shape (N, 3, H, W); TODO confirm against the preprocessing.

        Returns:
            ``(pooled, ada_patch_tokens)`` where:

            - ``pooled`` is the first output of ``_global_pool`` after
              ``ln_post``, multiplied by ``proj`` when ``proj`` is not None.
            - ``ada_patch_tokens`` is a list holding the adapters' ``fc1``
              (bottleneck) outputs, one per adapted block, each permuted back
              with ``permute(1, 0, 2)``.

            The list is ordered by block index, NOT by the order of
            ``self.features``. Indices in ``features`` that match no block
            contribute nothing.
        """
        encoder = self.image_encoder

        # Tokenize: (N, C, ...) -> (N, C, P) -> (N, P, C).
        x = encoder.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        # Broadcast the class embedding over the batch and prepend it as token 0.
        cls_token = encoder.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_token, x], dim=1)
        x = x + encoder.positional_embedding.to(x.dtype)

        x = encoder.patch_dropout(x)
        x = encoder.ln_pre(x)

        # Swap the first two axes for the transformer blocks; presumably they
        # expect sequence-first (L, N, D) layout — TODO confirm for this encoder.
        x = x.permute(1, 0, 2)

        ada_patch_tokens = []
        for layer_idx, block in enumerate(encoder.transformer.resblocks, start=1):
            # The unpacking requires each block to return a 2-item sequence.
            # NOTE(review): stock open_clip blocks return a single tensor, so this
            # presumably relies on a patched encoder — confirm.
            x, _ = block(x, attn_mask=None)
            if layer_idx not in self.features:
                continue
            adapter = self.adapters[self.features.index(layer_idx)]
            adapt_med, adapt_out = adapter(x)
            # Fixed residual blend. It keeps the block output dominant; the 0.9 / 0.1
            # values are hard-coded — presumably tuned by the authors.
            x = 0.9 * x + 0.1 * adapt_out
            ada_patch_tokens.append(adapt_med)

        # Undo the axis swap on the main stream and on every collected activation.
        x = x.permute(1, 0, 2)
        ada_patch_tokens = [tokens.permute(1, 0, 2) for tokens in ada_patch_tokens]

        pooled, _ = encoder._global_pool(x)
        pooled = encoder.ln_post(pooled)
        if encoder.proj is not None:
            pooled = pooled @ encoder.proj

        return pooled, ada_patch_tokens
|
|
|
|
|
|
|
|
|
|
|
|