# boringKey's picture
# Upload 236 files
# 5fee096 verified
'''
Code Reference:
* https://github.com/jadore801120/attention-is-all-you-need-pytorch/
* https://github.com/GT-RIPL/CODA-Prompt
* https://github.com/openai/CLIP
'''
import os
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from collections import Counter
from timm.models.vision_transformer import PatchEmbed
from timm.models.layers import trunc_normal_, DropPath
from scipy.special import softmax
from .petl.adapter import Adapter, MaskedAdapter
from .petl.proj import Proj
from .prompt import L2P
# Helper
class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.

    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.

    There are two functions:
      dispatch - take an input Tensor and create input Tensors for each expert.
      combine - take output Tensors from each expert and form a combined output
        Tensor. Outputs from different experts for the same batch element are
        summed together, weighted by the provided "gates".

    The class is initialized with a "gates" Tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.

    The inputs and outputs are all two-dimensional [batch, depth].
    Caller is responsible for collapsing additional dimensions prior to
    calling this class and reshaping the output to the original shape.

    Example use:
      gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
      inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
      experts: a list of length `num_experts` containing sub-networks.

      dispatcher = SparseDispatcher(num_experts, gates)
      expert_inputs = dispatcher.dispatch(inputs)
      expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
      outputs = dispatcher.combine(expert_outputs)

    The preceding code sets the output for a particular example b to:
      output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))

    This class takes advantage of sparsity in the gate matrix by including in
    the `Tensor`s for expert i only the batch elements for which
    `gates[b, i] != 0`.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
          num_experts: number of experts (the width of `gates`).
          gates: a `Tensor` of shape `[batch_size, num_experts]`; a non-zero
            entry `gates[b, e]` routes batch element `b` to expert `e`.
        """
        self._gates = gates
        self._num_experts = num_experts
        # (batch index, expert index) of every non-zero gate, shape [n_nonzero, 2].
        nonzero = torch.nonzero(gates)
        # sort(0) sorts each column independently; column 1 becomes the expert
        # ids in ascending order and its permutation tells us which row of
        # `nonzero` ended up at each position.
        sorted_experts, index_sorted_experts = nonzero.sort(0)
        # Drop the (independently sorted) batch column, keep experts: [n_nonzero, 1].
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch index of each routed row, grouped by expert.
        self._batch_index = nonzero[index_sorted_experts[:, 1], 0]
        # Number of rows each expert receives. Uses the same `!= 0` criterion
        # as torch.nonzero so the split sizes always add up to n_nonzero.
        self._part_sizes = (gates != 0).sum(0).tolist()
        # Gate value of every routed (batch, expert) pair, shape [n_nonzero, 1].
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.

        The `Tensor` for expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] != 0`.

        Args:
          inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`.
        Returns:
          a tuple of `num_experts` `Tensor`s with shapes
          `[expert_batch_size_i, <extra_input_dims>]`.
        """
        # Gather the routed rows (grouped by expert). squeeze(1) only has an
        # effect when dimension 1 has size 1.
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.

        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values. If `multiply_by_gates` is set to False, the
        gate values are ignored.

        Args:
          expert_out: a list of `num_experts` `Tensor`s, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a float32 `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            # Weight every expert output row by its gate value.
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), device=stitched.device)
        # Scatter-add rows back to their batch position; rows of the same batch
        # element (one per expert that processed it) are summed.
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.

        Returns:
          a tuple of `num_experts` `Tensor`s (same dtype as `gates`) with shapes
          `[expert_batch_size_i, 1]`.
        """
        # Split the routed gate values per expert, matching dispatch().
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
# Sub-module of Attention
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32.

    The input is upcast to float32 before normalization and the result is cast
    back to the caller's dtype, so fp16 inputs are handled safely.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
# Attention
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional key/value prefix ("prompt").

    Args:
        dim: embedding size ``C`` of the input tokens.
        num_heads: number of heads; each head has ``dim // num_heads`` channels.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        qk_scale: explicit multiplier for the attention logits; ``None`` selects
            the usual ``head_dim ** -0.5``.
        attn_drop: dropout probability on the attention weights (0 disables it).
        proj_drop: dropout probability after the output projection (0 disables it).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        # Compare with None (not truthiness) so an explicit qk_scale of 0.0 is honoured.
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0. else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()
        # Populated by forward(register_hook=True).
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None):
        """Attend over ``x`` of shape ``(B, N, C)`` and return a ``(B, N, C)`` tensor.

        Args:
            x: input tokens, ``(B, N, C)``.
            attn_mask: additive mask; it is ``unsqueeze(0)``-ed and added in place
                to the ``(B, num_heads, N, N_keys)`` logits, so it must broadcast
                to that shape.
            register_hook: store the attention map and register a hook that saves
                its gradient.
            prompt: optional ``(pk, pv)``; each is reshaped to
                ``(B, -1, num_heads, head_dim)`` and prepended to the keys/values.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # [3, B, NH, N, HD]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        if prompt is not None:
            pk, pv = prompt
            pk = pk.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            pv = pv.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            # Prefix the prompt tokens to keys and values (not queries).
            k = torch.cat((pk, k), dim=2)
            v = torch.cat((pv, v), dim=2)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MultiHeadAttention_LoRA(MultiHeadAttention):
    """Multi-head attention with LoRA updates applied to the key and value projections.

    While ``apply_lora`` is True, the effective weights are ``W_k + B_k @ A_k``
    and ``W_v + B_v @ A_v``; the query weight is unchanged. Only the ``.weight``
    of the LoRA layers is used, so ``lora_bias`` has no effect on the output.
    With ``get_input_matrix=True`` the forward pass also keeps a running mean
    (over all tokens seen) of ``x_t x_t^T`` in ``cur_matrix`` on the CPU.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.lora_rank = lora_rank
        self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        # LoRA is only added to the weights once init_param() enables it.
        self.apply_lora = False
        # Running mean of per-token input outer products (C x C), kept on the CPU.
        self.cur_matrix = torch.zeros(self.dim, self.dim)
        # Number of tokens averaged into cur_matrix so far.
        self.n_cur_matrix = 0

    def init_param(self):
        """(Re)initialise LoRA (A: kaiming-uniform, B: zeros) and enable it."""
        nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5))
        # Zero B so the adapted weights start equal to the frozen ones.
        nn.init.zeros_(self.lora_B_k.weight)
        nn.init.zeros_(self.lora_B_v.weight)
        self.apply_lora = True

    def merge_weight(self):
        """Fold the current LoRA update into ``self.qkv.weight`` and disable LoRA."""
        # The merge is a weight edit, not part of any forward graph.
        with torch.no_grad():
            q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
            self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0)
        self.apply_lora = False

    def reset_input_matrix(self):
        """Clear the running input statistics."""
        self.cur_matrix.zero_()
        self.n_cur_matrix = 0

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False):
        """Attend over ``x`` ``(B, N, C)`` with LoRA-adapted key/value weights.

        ``prompt`` is accepted for interface compatibility but is not used here.
        """
        if get_input_matrix:
            x_detached = x.detach()
            batch_gram = torch.bmm(x_detached.permute(0, 2, 1), x_detached).sum(dim=0).cpu()
            n_new = x.shape[0] * x.shape[1]
            # Incremental mean over all tokens seen since the last reset.
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + batch_gram) / (self.n_cur_matrix + n_new)
            self.n_cur_matrix += n_new
        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
        if self.apply_lora:
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
        # Detached like the original `.data`, but tolerates qkv_bias=False.
        bias = self.qkv.bias.detach() if self.qkv.bias is not None else None
        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MultiHeadAttention_SDLoRA(MultiHeadAttention):
    """Attention with a growing list of LoRA adapters on the query and value projections.

    Every init_param() call appends one adapter pair for q and one for v. In
    forward(), the newest adapter is scaled by ``mag_lora[-1]``. Each older
    adapter ``i`` is divided by ``||B_i|| * ||A_i||`` and scaled by
    ``mag_lora[i] + assimilated_mag_lora_*[i]``; it is skipped if either norm is zero.

    NOTE(review): ``self.mag_lora`` is read but never assigned in this class;
    presumably the owning model attaches it — confirm.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.lora_rank = lora_rank
        self.lora_bias = lora_bias
        self.lora_A_q_list = nn.ModuleList([])
        self.lora_B_q_list = nn.ModuleList([])
        self.lora_A_v_list = nn.ModuleList([])
        self.lora_B_v_list = nn.ModuleList([])
        # Extra per-adapter magnitudes (plain tensors, not parameters).
        self.assimilated_mag_lora_q = []
        self.assimilated_mag_lora_v = []

    def init_param(self):
        """Append a new q/v adapter pair (A kaiming-uniform, B zero) and zero assimilated magnitudes."""
        make_down = lambda: nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)
        make_up = lambda: nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)
        self.lora_A_q_list.append(make_down())
        self.lora_B_q_list.append(make_up())
        self.lora_A_v_list.append(make_down())
        self.lora_B_v_list.append(make_up())
        for down in (self.lora_A_q_list[-1], self.lora_A_v_list[-1]):
            nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5))
        for up in (self.lora_B_q_list[-1], self.lora_B_v_list[-1]):
            nn.init.zeros_(up.weight)
        device = self.qkv.weight.device
        self.assimilated_mag_lora_q.append(torch.Tensor([0.0]).to(device))
        self.assimilated_mag_lora_v.append(torch.Tensor([0.0]).to(device))
        assert len(self.lora_A_q_list) == len(self.mag_lora)
        assert len(self.mag_lora) == len(self.assimilated_mag_lora_q)

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None):
        """Attend over ``x`` ``(B, N, C)`` with the summed LoRA deltas added to q and v; ``prompt`` is unused."""
        B, N, C = x.shape
        n_adapters = len(self.lora_A_q_list)

        def lora_delta(down_list, up_list, assimilated):
            # Newest adapter, scaled by its raw magnitude.
            delta = self.mag_lora[-1] * up_list[-1](down_list[-1](x))
            # Older adapters: normalised direction times (magnitude + assimilated).
            for idx in range(n_adapters - 1):
                up_norm = torch.norm(up_list[idx].weight)
                down_norm = torch.norm(down_list[idx].weight)
                if up_norm == 0 or down_norm == 0:
                    continue  # Only in SD-LoRA-KD, where direction of lora being decomposed
                scale = self.mag_lora[idx] + assimilated[idx]
                delta += scale * up_list[idx](down_list[idx](x)) / (up_norm * down_norm)
            return delta

        q_delta = lora_delta(self.lora_A_q_list, self.lora_B_q_list, self.assimilated_mag_lora_q)
        v_delta = lora_delta(self.lora_A_v_list, self.lora_B_v_list, self.assimilated_mag_lora_v)

        qkv = self.qkv(x)
        # First `dim` channels are q, last `dim` are v.
        qkv[:, :, : self.dim] += q_delta
        qkv[:, :, -self.dim :] += v_delta
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        logits = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            logits += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = self.attn_drop(logits.softmax(dim=-1))
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class MultiHeadAttention_LoRA_Sub(MultiHeadAttention):
    """Attention with LoRA on the key/value projections plus accumulated past updates.

    ``prev_k_weight`` / ``prev_v_weight`` buffers accumulate ``B @ A`` each time
    save_weight() is called. forward() then uses one of three weight variants
    (see the branch comments).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.lora_rank = lora_rank
        self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.apply_lora = False
        # Running mean of per-token input outer products (C x C), kept on the CPU.
        self.cur_matrix = torch.zeros(self.dim, self.dim)
        self.n_cur_matrix = 0
        # Sum of all LoRA updates committed via save_weight().
        self.register_buffer("prev_k_weight", torch.zeros(self.dim, self.dim))
        self.register_buffer("prev_v_weight", torch.zeros(self.dim, self.dim))

    def init_param(self):
        """(Re)initialise LoRA (A: kaiming-uniform, B: zeros) and enable it."""
        nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k.weight)
        nn.init.zeros_(self.lora_B_v.weight)
        self.apply_lora = True

    def save_weight(self):
        """Add the current LoRA update into the prev_* buffers and disable LoRA."""
        # Without no_grad the in-place add would attach autograd history (through
        # the LoRA params) to the buffers; after init_param() re-initialises those
        # params in place, the next backward would fail on a stale saved tensor.
        with torch.no_grad():
            self.prev_k_weight += self.lora_B_k.weight @ self.lora_A_k.weight
            self.prev_v_weight += self.lora_B_v.weight @ self.lora_A_v.weight
        self.apply_lora = False

    def reset_input_matrix(self):
        """Clear the running input statistics."""
        self.cur_matrix.zero_()
        self.n_cur_matrix = 0

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False):
        """Attend over ``x`` ``(B, N, C)``; ``prompt`` is accepted but unused."""
        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
        if get_input_matrix:
            # Only in before_task getting input matrix
            x_detached = x.detach()
            batch_gram = torch.bmm(x_detached.permute(0, 2, 1), x_detached).sum(dim=0).cpu()
            n_new = x.shape[0] * x.shape[1]
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + batch_gram) / (self.n_cur_matrix + n_new)
            self.n_cur_matrix += n_new
            # NOTE(review): this branch subtracts the accumulated updates while the
            # other two add them — confirm this is intended.
            k_weight = k_weight - self.prev_k_weight
            v_weight = v_weight - self.prev_v_weight
        elif self.apply_lora:
            # Only in training
            k_weight = k_weight + self.prev_k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.prev_v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
        else:
            # Only in testing
            k_weight = k_weight + self.prev_k_weight
            v_weight = v_weight + self.prev_v_weight
        # Detached like the original `.data`, but tolerates qkv_bias=False.
        bias = self.qkv.bias.detach() if self.qkv.bias is not None else None
        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MultiHeadAttention_CL_LoRA(MultiHeadAttention_LoRA):
    """Attention whose q/v LoRA deltas are supplied per call through ``adapt``.

    The parent's k-LoRA layers are deleted and q-LoRA layers are added.
    NOTE(review): the inherited merge_weight() still references lora_A_k /
    lora_B_k and would therefore fail on this class.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
        del self.lora_A_k
        del self.lora_B_k
        self.lora_A_q = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_q = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)

    def init_param(self):
        """Set A_q / A_v to random orthonormal row bases (via QR) and B_q / B_v to zero."""
        q1, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank))
        q2, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank))
        with torch.no_grad():
            self.lora_A_q.weight.copy_(q1.T)
            self.lora_A_v.weight.copy_(q2.T)
        scaling_factor = 1.  # You can adjust this value if needed
        self.lora_A_q.weight.data *= scaling_factor
        self.lora_A_v.weight.data *= scaling_factor
        nn.init.zeros_(self.lora_B_q.weight)
        nn.init.zeros_(self.lora_B_v.weight)

    def forward(
            self,
            x,
            adapt=None,
            prompt=None,
            rank_prompt=None,
            block_weight=None,
            attn_mask=None,
            register_hook=False):
        """Attend over ``x`` ``(B, N, C)``, adding ``adapt[0](x)`` to q and ``adapt[2](x)`` to v.

        Args:
            adapt: indexable of callables (or None). Only ``adapt[0]`` (q) and
                ``adapt[2]`` (v) are used; each must map ``x`` to ``(B, N, dim)``.
            block_weight: per-projection multipliers indexed 0 (q) and 2 (v);
                defaults to ones.
            prompt, rank_prompt: accepted for interface compatibility; unused.
        """
        # custom_adapt including the lora and lora scale weight
        # since this method has many set of weights during training/inference, keep changing the module weight is quite exhausting
        # lets just pass in as argument
        B, N, C = x.shape
        # Detached like the original `.data`, but tolerates qkv_bias=False.
        bias = self.qkv.bias.detach() if self.qkv.bias is not None else None
        qkv = F.linear(x, self.qkv.weight, bias)
        if adapt is not None:
            if block_weight is None:
                block_weight = torch.ones(3).to(x.device)
            qq = block_weight[0] * adapt[0](x)
            vv = block_weight[2] * adapt[2](x)
            # First `dim` channels are q, last `dim` are v.
            qkv[:, :, : self.dim] += qq
            qkv[:, :, -self.dim :] += vv
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# MInfLoRA
class MultiHeadAttention_MaskedLoRA(MultiHeadAttention_LoRA):
    """Attention module with masked (projection) LoRA, applied to k and v.

    On top of the LoRA parent, each task id (0..9) can carry up to two basis
    matrices ``S`` (registered via enable_scale) and a learnable square scale
    ``M``. forward() rewrites ``W <- W + W @ S @ (M M^T - I) @ S^T`` for k and v
    for each enabled basis of ``expert_id``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
        # Trgp implementation
        self.identity_matrix = torch.eye(self.qkv.weight.shape[1])
        # space[t] = [S_0, S_1]; 0 means "not set". Presumably S is (C, r) — confirm.
        self.space = [[0, 0] for _ in range(10)]
        # Two learnable scales per task, initialised to identity. clone(): nn.Parameter
        # does not copy its input, so without it all 20 parameters (and
        # identity_matrix itself) would share a single storage.
        self.scale_param = nn.ModuleList([nn.ParameterList([nn.Parameter(self.identity_matrix.clone()) for _ in range(2)]) for _ in range(10)])
        self.scaling_mask = [[False, False] for _ in range(10)]

    def enable_scale(self, task_id, space):
        """Register one or two basis matrices for ``task_id`` and enable their scaling.

        Sequences of any other length are ignored. Device placement of the scale
        parameters is handled by the usual ``module.to(...)``.
        """
        if len(space) == 2:
            self.space[task_id][0] = space[0]
            self.space[task_id][1] = space[1]
            self.scaling_mask[task_id][0] = True
            self.scaling_mask[task_id][1] = True
        elif len(space) == 1:
            self.space[task_id][0] = space[0]
            self.scaling_mask[task_id][0] = True

    def forward(self, x, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix = False):
        """Attend over ``x`` ``(B, N, C)`` with LoRA and the task-``expert_id`` scaling; ``prompt`` is unused."""
        if get_input_matrix:
            x_detached = x.detach()
            batch_gram = torch.bmm(x_detached.permute(0, 2, 1), x_detached).sum(dim=0).cpu()
            n_new = x.shape[0] * x.shape[1]
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + batch_gram) / (self.n_cur_matrix + n_new)
            self.n_cur_matrix += n_new
        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
        if self.apply_lora:
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
        for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]):
            if not mask:
                break
            scale_size = space.shape[1]
            cropped_scale = scale[:scale_size, :scale_size]
            cropped_scale = cropped_scale @ cropped_scale.T  # better, idk why
            cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device)
            # Rescale only the component of W that acts on span(space).
            k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
            v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
        # Detached like the original `.data`, but tolerates qkv_bias=False.
        bias = self.qkv.bias.detach() if self.qkv.bias is not None else None
        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# MInfLoRA1
class MultiHeadAttention_MaskedLoRA1(MultiHeadAttention):
    """Attention with one LoRA adapter pair (key and value) per task, all summed into the weights.

    Each init_param() call appends an adapter for the next task and advances
    ``cur_task``. forward() takes the same ``(x, x_proj, probs, ...)`` arguments
    as MultiHeadAttention_MultiMaskedLoRA.forward, ignores ``x_proj`` and returns
    ``(out, out, probs)``.

    NOTE(review): ``self.feature_mat`` (read during input collection when
    ``cur_task > 0``) is not assigned in this class; presumably the training
    loop sets it — confirm.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        # Index of the newest adapter; -1 until the first init_param().
        self.cur_task = -1
        self.lora_rank = lora_rank
        self.cur_matrix = torch.zeros(self.dim, self.dim)
        self.n_cur_matrix = 0
        # Inputs collected while get_input_matrix=True. Initialised here so that
        # forward(get_input_matrix=True) works before reset_input_matrix().
        self.cur_matrixs = []
        self.lora_bias = lora_bias
        self.lora_A_k_list = nn.ModuleList([])
        self.lora_B_k_list = nn.ModuleList([])
        self.lora_A_v_list = nn.ModuleList([])
        self.lora_B_v_list = nn.ModuleList([])
        self.space_k = [0 for _ in range(10)]
        self.space_v = [0 for _ in range(10)]
        self.identity_matrix = torch.eye(self.qkv.weight.shape[1])
        self.scale_param = nn.ParameterList([])

    def init_param(self):
        """Append a new k/v adapter (A kaiming-uniform, B zero) plus an identity scale, and advance cur_task."""
        self.lora_A_k_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias))
        self.lora_B_k_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias))
        self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias))
        self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias))
        # clone(): nn.Parameter wraps its input without copying, so without it
        # every task's scale (and identity_matrix) would share one storage.
        self.scale_param.append(nn.Parameter(self.identity_matrix.clone()))
        nn.init.kaiming_uniform_(self.lora_A_k_list[-1].weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k_list[-1].weight)
        nn.init.zeros_(self.lora_B_v_list[-1].weight)
        self.cur_task += 1

    def reset_input_matrix(self):
        """Discard the inputs collected by forward(get_input_matrix=True)."""
        self.cur_matrixs = []

    def _realign_current_lora(self, x):
        """Re-express the newest adapter in a new input basis derived from ``x``.

        A_new is the top ``lora_rank`` left singular vectors (divided by sqrt(3))
        of the batch-averaged Gram matrix, first multiplied by A_k^T A_k and, for
        ``cur_task > 0``, with ``feature_mat @ activation`` subtracted. B is set to
        ``B @ A_k_old @ pinv(A_new)``, i.e. B_new @ A_new equals B_old @ A_k_old
        restricted to the row space of A_new. Both A_k and A_v are set to A_new.

        NOTE(review): B_v is re-expressed with A_k_old, not A_v_old; this only
        matches B_v @ A_v when the two A matrices already agree (or B_v is zero).
        """
        # Pure weight surgery: no autograd graph needed.
        with torch.no_grad():
            activation = torch.bmm(x.permute(0, 2, 1), x).sum(dim=0) / x.shape[0]
            # get the intersect between previous activation and curr activation
            a_old = self.lora_A_k_list[-1].weight.data
            activation = a_old.T @ a_old @ activation
            if self.cur_task > 0:
                activation = activation - self.feature_mat @ activation
            U, _, _ = torch.linalg.svd(activation, full_matrices=False)
            a_new = U[:, :self.lora_rank].T / math.sqrt(3)
            tmp = a_old @ torch.pinverse(a_new)
            bk_new = self.lora_B_k_list[-1].weight.data @ tmp
            bv_new = self.lora_B_v_list[-1].weight.data @ tmp
            # a_old aliases lora_A_k's storage, so all products are formed before copying.
            self.lora_A_k_list[-1].weight.data.copy_(a_new)
            self.lora_A_v_list[-1].weight.data.copy_(a_new)
            self.lora_B_k_list[-1].weight.data.copy_(bk_new)
            self.lora_B_v_list[-1].weight.data.copy_(bv_new)

    def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
        """Attend over ``x`` ``(B, N, C)`` with every task's LoRA update added to k and v.

        ``x_proj``, ``expert_id`` and ``prompt`` are accepted for interface
        compatibility but unused. Returns ``(out, out, probs)`` with ``probs``
        unchanged.
        """
        if get_input_matrix:
            assert x.shape[0] < 512
            self.cur_matrixs.append(x.detach())
            if x.shape[0] > 128:
                # do some drift check here
                self._realign_current_lora(x)
        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
        # Previous tasks' adapters ...
        for ii in range(self.cur_task):
            k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight
            v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight
        # ... plus the newest one.
        k_weight = k_weight + self.lora_B_k_list[-1].weight @ self.lora_A_k_list[-1].weight
        v_weight = v_weight + self.lora_B_v_list[-1].weight @ self.lora_A_v_list[-1].weight
        # A disabled variant (gated by `use_scale = False`) replaced each past
        # adapter's component on span(space_k/space_v[ii]) by a copy rescaled with
        # mag_lora[ii] / (||B|| * ||A||); it had no effect and was removed.
        # Detached like the original `.data`, but tolerates qkv_bias=False.
        bias = self.qkv.bias.detach() if self.qkv.bias is not None else None
        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, x, probs
# MInfLoRA2
class MultiHeadAttention_MultiMaskedLoRA(MultiHeadAttention_MaskedLoRA):
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
    """Add task-routing state on top of the masked-LoRA parent."""
    super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
    # Highest task id registered via save_space(); routing considers tasks 0..activated_expert.
    self.activated_expert = 0
    # Per-task [basis, unused] slots; placeholders until save_space() fills slot 0.
    self.saved_space = [[torch.tensor(1), torch.tensor(1)] for _ in range(10)]
    # Counters (not updated by any code visible in this chunk).
    self.total = 0
    self.hit = 0
    # Statistics of projected inputs, cleared by reset_input_matrix().
    self.n_projected_cur_matrix = 0
    self.projected_cur_matrix = torch.zeros(self.dim, self.dim)
def reset_input_matrix(self):
    """Clear the inherited input statistics and the projected-input statistics."""
    super().reset_input_matrix()
    self.n_projected_cur_matrix = 0
    self.projected_cur_matrix.zero_()
def enable_scale(self, task_id, space):
    """Register one or two basis matrices for ``task_id`` and enable their scaling.

    The logic is identical to MultiHeadAttention_MaskedLoRA.enable_scale, so this
    override simply delegates to it instead of duplicating the code.
    """
    super().enable_scale(task_id, space)
def save_space(self, task_id, space):
    """Store ``space`` as the routing basis of ``task_id`` and make it the latest routable task."""
    self.saved_space[task_id][0] = space
    self.activated_expert = task_id
def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
    """Run a plain LoRA attention pass on ``x`` and a task-scaled pass on ``x_proj``.

    Args:
        x: input tokens ``(B, N, C)`` for the plain pass.
        x_proj: input tokens for the scaled pass; reshaped with the same ``B, N, C``.
        probs: list; in eval mode the per-task routing probabilities are appended.
        attn_mask: additive mask, ``unsqueeze(0)``-ed before being added to the logits.
        expert_id: task whose scaling is applied in the second pass. In eval mode
            (and when not collecting statistics) it is replaced by the routed task.
        register_hook: store attention maps / gradient hooks (for both passes).
        prompt: accepted for interface compatibility; unused.
        get_input_matrix: update the running mean of ``x_t x_t^T`` in ``cur_matrix``.
    Returns:
        ``(x_out, x_proj_out, probs)``.
    """
    B, N, C = x.shape
    if get_input_matrix:
        assert expert_id == 0
        x_detached = x.detach()
        batch_gram = torch.bmm(x_detached.permute(0, 2, 1), x_detached).sum(dim=0).cpu()
        self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + batch_gram) / (self.n_cur_matrix + B * N)
        self.n_cur_matrix += B * N
    # Route by the batch-summed Gram matrix G: score each task t by
    # ||S_t^T G||_F, softmax over tasks, and pick the argmax as expert_id.
    if not self.training and not get_input_matrix:
        with torch.no_grad():
            cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0) / (B * N)  # (C, C)
            saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device)  # (task_num, C, r)
            proj_mat = saved.transpose(1, 2)  # (task_num, r, C)
            proj_mat = torch.einsum('ijk,kl->ijl', proj_mat, cur_cur_matrix)  # (task_num, r, C) @ (C, C)
            proj_norm = np.linalg.norm(proj_mat.cpu(), axis=(1, 2))  # (task_num, )
            proj_norm = softmax(proj_norm)
            probs.append(proj_norm)
            expert_id = np.argmax(proj_norm, axis=0)
    # Detached like the original `.data`, but tolerates qkv_bias=False.
    bias = self.qkv.bias.detach() if self.qkv.bias is not None else None

    # ---- Pass 1: plain (LoRA-adapted) attention on x.
    q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
    if self.apply_lora:
        k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
        v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
    qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    attn = (q @ k.transpose(-2, -1)) * self.scale
    if attn_mask is not None:
        attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    if register_hook:
        self.save_attention_map(attn)
        attn.register_hook(self.save_attn_gradients)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)

    # ---- Pass 2: same weights, rescaled on the enabled spaces of expert_id, applied to x_proj.
    for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]):
        if not mask:
            break
        scale_size = space.shape[1]
        cropped_scale = scale[:scale_size, :scale_size]
        cropped_scale = cropped_scale @ cropped_scale.T  # better, idk why
        cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device)
        k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
        v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
    qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    attn = (q @ k.transpose(-2, -1)) * self.scale
    if attn_mask is not None:
        attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    if register_hook:
        self.save_attention_map(attn)
        attn.register_hook(self.save_attn_gradients)
    x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x_proj = self.proj(x_proj)
    x_proj = self.proj_drop(x_proj)
    return x, x_proj, probs
def forward1(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
    """Two-stream attention with per-sample expert routing at evaluation time.

    ``x`` is attended with the base (optionally LoRA-augmented) q/k/v weights.
    ``x_proj`` is attended with k/v weights that are additionally rescaled
    inside the stored subspaces ``self.space[e]`` of an expert ``e``:

    * training, or when collecting the input matrix: ``e = expert_id``;
    * evaluation: every sample is routed to the expert in
      ``0..self.activated_expert`` whose saved space ``S`` maximises the
      Frobenius norm of ``S^T (x^T x / N)``. The softmax over experts of
      these norms is appended to ``probs``.

    Args:
        x, x_proj: 3-D input tensors unpacked as (B, N, C).
        probs: list; in evaluation mode one numpy array of shape
            (activated_expert + 1, B) is appended per call.
        attn_mask: optional additive attention mask, broadcast over heads.
        expert_id: expert used outside evaluation; must be 0 when
            ``get_input_matrix`` is set.
        register_hook: if True, save the attention map and register a
            gradient hook on it.
        prompt: unused.
        get_input_matrix: if True, fold x^T x into the running mean stored
            in ``self.cur_matrix`` / ``self.n_cur_matrix``.

    Returns:
        tuple ``(x_out, x_proj_out, probs)``.
    """
    B, N, C = x.shape
    if get_input_matrix:
        assert expert_id == 0
        # Running mean of x^T x over all B * N tokens seen so far (CPU copy).
        self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N)
        self.n_cur_matrix += B * N
    # Pick one expert per sample (evaluation only).
    if not self.training and not get_input_matrix:
        with torch.no_grad():
            cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()) / N  # (B, C, C)
            cur_cur_matrix = cur_cur_matrix.permute(1, 2, 0)  # (C, C, B)
            saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device)  # (task_num, C, r)
            proj_mat = saved.transpose(1, 2)  # (task_num, r, C)
            proj_mat = torch.einsum('ijk,klm->ijlm', proj_mat, cur_cur_matrix)  # (task_num, r, C) @ (C, C, B) -> (task_num, r, C, B)
            # Frobenius norm over (r, C) computed in torch so CUDA tensors work
            # (np.linalg.norm cannot read device memory); converted to numpy so
            # the values appended to ``probs`` keep their original type.
            proj_norm = torch.linalg.norm(proj_mat, dim=(1, 2)).cpu().numpy()  # (task_num, B)
            proj_norm = softmax(proj_norm, axis=0)  # (task_num, B)
            probs.append(proj_norm)
            selected_expert_id = np.argmax(proj_norm, axis=0)  # (B, )
            selected_expert_id = torch.tensor(selected_expert_id).to(x.device)
    # ---- Stream 1: base weights (+ LoRA) on x
    q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
    if self.apply_lora:
        k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
        v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
    qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    attn = (q @ k.transpose(-2, -1)) * self.scale
    if attn_mask is not None:
        attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    if register_hook:
        self.save_attention_map(attn)
        attn.register_hook(self.save_attn_gradients)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    # ---- Stream 2: subspace-rescaled weights on x_proj
    if not self.training and not get_input_matrix:  # Test
        inputs = [x_proj.clone() for _ in range(self.activated_expert + 1)]
        k_weights = [k_weight.clone() for _ in range(self.activated_expert + 1)]
        v_weights = [v_weight.clone() for _ in range(self.activated_expert + 1)]
        qkv_outputs = []
        for ex in range(self.activated_expert + 1):
            # Build expert ``ex``'s rescaled k/v weights.
            for mask, scale, space in zip(self.scaling_mask[ex], self.scale_param[ex], self.space[ex]):
                if not mask:
                    break
                scale_size = space.shape[1]
                cropped_scale = scale[:scale_size, :scale_size]
                cropped_scale = cropped_scale @ cropped_scale.T  # better, idk why
                cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(x.device)
                k_weights[ex] = k_weights[ex] + k_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
                v_weights[ex] = v_weights[ex] + v_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
            # Zero out samples not routed to ``ex``; F.linear is linear in its
            # input, so summing the per-expert outputs (bias added once below)
            # gives each sample the output of its own expert.
            cur_selected = selected_expert_id.unsqueeze(-1).unsqueeze(-1)
            mask = (cur_selected == ex)
            inputs[ex] *= mask
            inputs[ex] = inputs[ex].to(x.device)
            q_weight = q_weight.to(x.device)
            k_weights[ex] = k_weights[ex].to(x.device)
            v_weights[ex] = v_weights[ex].to(x.device)
            qkv = F.linear(inputs[ex], torch.cat([q_weight, k_weights[ex], v_weights[ex]], dim=0))
            qkv_outputs.append(qkv)
        stacked = torch.stack(qkv_outputs)
        qkv = torch.sum(stacked, dim=0)
        qkv = qkv + self.qkv.bias
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_proj = self.proj(x_proj)
        x_proj = self.proj_drop(x_proj)
    else:
        # Training / input-matrix collection: use expert ``expert_id`` for all samples.
        for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]):
            if not mask:
                break
            scale_size = space.shape[1]
            cropped_scale = scale[:scale_size, :scale_size]
            cropped_scale = cropped_scale @ cropped_scale.T  # better, idk why
            cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device)
            k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
            v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
        qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_proj = self.proj(x_proj)
        x_proj = self.proj_drop(x_proj)
    return x, x_proj, probs
# MInfLoRA3
class MultiHeadAttention_MultiMaskedLoRA3(MultiHeadAttention_MaskedLoRA):
    """Masked-LoRA attention keeping one k/v LoRA pair per task (10 slots).

    ``init_param`` activates the next task's LoRA pair. ``forward`` adds the
    LoRA deltas of every task up to ``cur_task`` to the k/v weights and, for
    tasks with a stored subspace in ``space_k``/``space_v`` (the placeholder
    ``0`` means "none"), rescales the weights inside that subspace with the
    diagonal of the task's ``scale_param`` matrix.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
        self.cur_task = -1
        self.lora_A_k_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)])
        self.lora_B_k_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)])
        self.lora_A_v_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)])
        self.lora_B_v_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)])
        # Per-task subspace bases; the int 0 marks "no subspace stored yet".
        self.space_k = [0 for _ in range(10)]
        self.space_v = [0 for _ in range(10)]
        # nn.Parameter(t) wraps t's storage without copying, so each task gets
        # its own copy; otherwise all ten scales (and identity_matrix) would
        # alias a single tensor and training one task would change them all.
        self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix.detach().clone()) for _ in range(10)])

    def init_param(self):
        """Advance ``cur_task`` and initialise that task's LoRA pair.

        A is Kaiming-uniform initialised (a=sqrt(5)) and B is zeroed, so the
        new delta ``B @ A`` starts at zero.
        """
        self.cur_task += 1
        i = self.cur_task
        nn.init.kaiming_uniform_(self.lora_A_k_list[i].weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v_list[i].weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k_list[i].weight)
        nn.init.zeros_(self.lora_B_v_list[i].weight)

    def merge_weight(self):
        """Merging into ``qkv`` is disabled for this variant.

        Prints ``'Not MERGED'`` and returns 0 without modifying any weights.
        """
        print('Not MERGED')
        return 0

    def save_dir(self):
        """No-op for this variant: returns 0 without modifying any state."""
        return 0

    def enable_scale(self, task_id, space):
        """Store subspace bases for ``task_id`` and enable their scaling masks.

        Args:
            task_id: index into ``self.space`` / ``self.scaling_mask``.
            space: sequence of one or two bases; other lengths are ignored.
        """
        if len(space) == 2:
            self.space[task_id][0] = space[0]
            self.space[task_id][1] = space[1]
            self.scaling_mask[task_id][0] = True
            self.scaling_mask[task_id][1] = True
        elif len(space) == 1:
            self.space[task_id][0] = space[0]
            self.scaling_mask[task_id][0] = True

    def save_space(self, task_id, space):
        """Record ``space`` for ``task_id`` and mark that task as the newest active expert."""
        self.activated_expert = task_id
        self.saved_space[task_id].append(space)

    def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
        """Attention whose k/v weights merge the LoRA deltas of tasks ``0..cur_task``.

        For each task the delta ``B @ A`` is added to the k and v weights. If
        the task has a stored subspace ``S`` (r rows), the weight ``W`` is then
        replaced by ``W - W S^T S + W S^T D S`` where ``D`` is the diagonal of
        the task's scale matrix cropped to (r, r).

        ``x_proj``, ``expert_id`` and ``prompt`` are unused. Returns
        ``(out, out, probs)``, i.e. the attention output twice.
        """
        B, N, C = x.shape
        if get_input_matrix:
            # Running mean of x^T x over all B * N tokens seen so far (CPU copy).
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N)
            self.n_cur_matrix += B * N
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
        for ii in range(self.cur_task + 1):
            k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight
            v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight
            if isinstance(self.space_k[ii], int):
                continue  # no subspace stored for this task
            space_k = self.space_k[ii]
            space_v = self.space_v[ii]
            r = space_k.shape[0]
            # Q diagonal scaling. NOTE(review): v reuses this k-sized (r, r)
            # matrix, which assumes space_v has the same number of rows.
            # TODO: Change the Scale to remove scale, and scale by magnitude and direction
            # TODO2: following TODO 1, but now unfreeze the previous scale
            scalee = torch.diag(torch.diag(self.scale_param[ii][:r, :r]))
            k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ space_k.T @ scalee @ space_k
            v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ space_v.T @ scalee @ space_v
        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)  # For head axis broadcasting
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, x, probs
# MLP
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    This is the MLP used in Vision Transformer, MLP-Mixer and related networks.
    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are falsy).
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        output_dim = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # One dropout module shared by both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
# Blocks
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + Attn(LN(x))`` then ``x + MLP(LN(x))``.

    ``forward`` takes [Seq, Batch, Dim] input; ``attention`` swaps the first
    two axes around the call to ``self.attn`` and swaps them back.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 norm_layer_eps = 1e-5,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False
                 ):
        super().__init__()
        base_args = (d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        if attn_layer == MultiHeadAttention:
            self.attn = attn_layer(*base_args)
        elif (attn_layer == MultiHeadAttention_LoRA
              or attn_layer == MultiHeadAttention_SDLoRA
              or attn_layer == MultiHeadAttention_LoRA_Sub
              or attn_layer == MultiHeadAttention_MaskedLoRA
              or attn_layer == MultiHeadAttention_MultiMaskedLoRA
              or attn_layer == MultiHeadAttention_CL_LoRA):
            # LoRA-style layers additionally take rank and bias.
            self.attn = attn_layer(*base_args, lora_rank, lora_bias)
        else:
            assert 0, f'{attn_layer} not Implemented'
        self.ln_1 = norm_layer(d_model, eps=norm_layer_eps)
        self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer)
        self.ln_2 = norm_layer(d_model, eps=norm_layer_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn_mask = attn_mask
        self.text_or_image = text_or_image

    def attention(self, x: torch.Tensor, **kwargs):
        if self.attn_mask is not None:
            # Keep the mask on x's device and dtype.
            self.attn_mask = self.attn_mask.to(x)
        out = self.attn(x.permute(1, 0, 2), attn_mask=self.attn_mask, **kwargs)
        return out.permute(1, 0, 2)

    def forward(self, x: torch.Tensor, **kwargs):
        attn_out = self.attention(self.ln_1(x), **kwargs)
        x = x + self.drop_path(attn_out)  # [Seq, Batch, Dim]
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.drop_path(mlp_out)
class ResidualAttentionBlock_MLP(ResidualAttentionBlock):
    """Residual attention block with a parallel bottleneck ``Adapter`` on the MLP branch.

    The adapter output is added to the MLP output inside the second residual
    connection. With ``compute_lora_feat=True`` the adapter output is also
    stored (detached, on CPU) in ``self.lora_feature``.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 ):
        # Keyword arguments keep these aligned with the parent signature, which
        # has ``norm_layer_eps`` between ``norm_layer`` and ``attn_mask``.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            drop_path=drop_path,
            attn_layer=attn_layer,
            act_layer=act_layer,
            norm_layer=norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)
        self.ffn_num = 64
        self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none')
        self.lora_feature = None  # Temporary save the output of adapter, for method : DMNSP

    def forward(self, x: torch.Tensor, compute_lora_feat = False, **kwargs):
        """Attention residual, then ``x + drop_path(MLP(LN(x)) + adapter(x))``.

        Args:
            x: [Seq, Batch, Dim] input.
            compute_lora_feat: if True, store the adapter output in
                ``self.lora_feature`` (detached, on CPU).
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))  # [Seq, Batch, Dim]
        x_re = x.permute(1, 0, 2)
        adapt_x = self.adaptmlp(x_re, add_residual=False)
        adapt_x = adapt_x.permute(1, 0, 2)
        x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)
        if compute_lora_feat:
            self.lora_feature = adapt_x.detach().cpu()
        return x
class ResidualAttentionBlock_MaskedMLP(ResidualAttentionBlock):
    """Residual attention block with a parallel ``MaskedAdapter`` on the MLP branch.

    The adapter output is added to the MLP output inside the second residual
    connection; ``compute_input_matrix`` is forwarded to the adapter.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 ):
        # Keyword arguments keep these aligned with the parent signature, which
        # has ``norm_layer_eps`` between ``norm_layer`` and ``attn_mask``.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            drop_path=drop_path,
            attn_layer=attn_layer,
            act_layer=act_layer,
            norm_layer=norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)
        self.ffn_num = 64
        self.adaptmlp = MaskedAdapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                      init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none')

    def forward(self, x: torch.Tensor, compute_input_matrix = False, **kwargs):
        """Attention residual, then ``x + drop_path(MLP(LN(x)) + adapter(x))``.

        Args:
            x: [Seq, Batch, Dim] input.
            compute_input_matrix: forwarded to ``MaskedAdapter``.
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))  # [Seq, Batch, Dim]
        x_re = x.permute(1, 0, 2)
        adapt_x = self.adaptmlp(x_re, add_residual=False, compute_input_matrix=compute_input_matrix)
        adapt_x = adapt_x.permute(1, 0, 2)
        x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)
        return x
class ResidualAttentionBlock_MoE_MLP(ResidualAttentionBlock):
    """Residual attention block whose MLP branch adds a mixture of adapters.

    A noisy top-k router (https://arxiv.org/abs/1701.06538) scores every
    sample from its first sequence position and dispatches the whole sample
    to ``top_k`` of ``experts_num`` adapters. The gated adapter outputs are
    added to the MLP output inside the second residual connection.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 # MoE
                 step: int = 0,
                 experts_num: int = 0,
                 top_k: int = 0,
                 noisy_gating: bool = True
                 ):
        # Keyword arguments keep these aligned with the parent signature, which
        # has ``norm_layer_eps`` between ``norm_layer`` and ``attn_mask``.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            drop_path=drop_path,
            attn_layer=attn_layer,
            act_layer=act_layer,
            norm_layer=norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)
        assert top_k <= experts_num
        # Standard normal parameters for _prob_in_top_k.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        self.step = step
        self.top_k = top_k
        self.noisy_gating = noisy_gating
        self.ffn_num = 64
        self.experts_num = experts_num
        self.softmax = nn.Softmax(1)
        self.softplus = nn.Softplus()
        # One router / noise projection per step; forward currently uses index 0.
        self.router_list = nn.ParameterList([
            nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step)
        ])
        self.w_noise_list = nn.ParameterList([
            nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step)
        ])
        self.adaptmlp_list = nn.ModuleList([
            Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                    init_option='lora',
                    adapter_scalar=0.1,
                    adapter_layernorm_option='none')
            for _ in range(self.experts_num)
        ])
        self.lora_feature = None  # Temporary save the output of adapter, for method : DMNSP

    def attention(self, x: torch.Tensor, **kwargs):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        x = x.permute(1, 0, 2)
        attn = self.attn(x, attn_mask=self.attn_mask, **kwargs)
        attn = attn.permute(1, 0, 2)
        return attn

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty or single-element Tensor.
        Args:
            x: a `Tensor`.
        Returns:
            a `Scalar`.
        """
        eps = 1e-10
        # Variance is undefined for fewer than two values (e.g. num_experts == 1).
        if x.shape[0] <= 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean()**2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
            gates: a `Tensor` of shape [batch_size, n]
        Returns:
            an integer `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
                normally distributed noise with standard deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n], or None
            noisy_top_values: a `Tensor` of shape [batch, m].
                "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
        Returns:
            a `Tensor` of shape [batch, n].
        """
        # Local import so this method does not depend on a module-level ``Normal``.
        from torch.distributions.normal import Normal
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        # is each value currently in the top k.
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
        """Noisy top-k gating.
        See paper: https://arxiv.org/abs/1701.06538.
        Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - we only add noise at training time.
            noise_epsilon: a float
        Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: currently always None (load computation is disabled)
        """
        clean_logits = x @ w_gate.to(x)
        if self.noisy_gating and train:
            raw_noise_stddev = x @ w_noise.to(x)
            noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits
        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1)
        top_k_logits = top_logits[:, :self.top_k]
        top_k_indices = top_indices[:, :self.top_k]
        top_k_gates = self.softmax(top_k_logits)
        zeros = torch.zeros_like(logits)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)
        # The load term (via _prob_in_top_k / _gates_to_load) is currently unused.
        return gates, None

    def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs):
        """Attention residual, then ``x + drop_path(MLP(LN(x)) + MoE-adapters(x))``.

        Args:
            x: [Seq, Batch, Dim] input.
            compute_lora_feat: accepted for interface compatibility; unused here.
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))  # [Seq, Batch, Dim]
        x_re = x.permute(1, 0, 2)[:, 0, :]
        gates, _ = self.noisy_top_k_gating(x_re, self.training, self.router_list[0],
                                           self.w_noise_list[0])  # hardcoded, task_id = 0
        dispatcher = SparseDispatcher(self.experts_num, gates)
        # reshape rather than view: after permute x may be non-contiguous.
        expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).reshape(x.shape[1], -1))
        expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0],
                                                                       x.shape[0], x.shape[2]).to(x), add_residual=False)
                          for i in range(self.experts_num)]
        expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0]
        y = dispatcher.combine(expert_outputs)
        y = y.view(x.shape[1], x.shape[0], x.shape[2])
        x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2))
        return x
class ResidualAttentionBlock_MoE_Proj(ResidualAttentionBlock):
    """Residual attention block with 0, 1 or several adapters on the MLP branch.

    * ``experts_num == 0``: plain attention + MLP block.
    * ``experts_num == 1``: one adapter added to the MLP output.
    * ``experts_num > 1``: noisy top-2 routing over ``experts_num`` adapters.

    String values for ``attn_layer``, ``act_layer`` and ``norm_layer`` are
    looked up in this module's globals, falling back to the defaults.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 # MoE
                 experts_num=0,
                 ):
        # Every submodule is built below, so only nn.Module's initializer is
        # needed. ResidualAttentionBlock.__init__ requires d_model and n_head
        # (calling it without arguments raises TypeError).
        nn.Module.__init__(self)
        if isinstance(attn_layer, str):
            try:
                attn_layer = globals()[attn_layer]
            except KeyError:
                print(f'{attn_layer} not found, using default MultiHeadAttention')
                attn_layer = MultiHeadAttention
        if isinstance(act_layer, str):
            try:
                act_layer = globals()[act_layer]
            except KeyError:
                print(f'{act_layer} not found, using default nn.GELU')
                act_layer = nn.GELU
        if isinstance(norm_layer, str):
            try:
                norm_layer = globals()[norm_layer]
            except KeyError:
                print(f'{norm_layer} not found, using default nn.LayerNorm')
                norm_layer = nn.LayerNorm
        self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.ln_1 = norm_layer(d_model)
        self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer)
        self.ln_2 = norm_layer(d_model)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn_mask = attn_mask
        self.is_train = True
        # TODO : make it argument, now hardcoded
        if experts_num > 1:
            # Standard normal parameters for _prob_in_top_k.
            self.register_buffer("mean", torch.tensor([0.0]))
            self.register_buffer("std", torch.tensor([1.0]))
            self.step = 1
        else:
            self.step = 0
        self.top_k = 2
        self.ffn_num = 64
        self.experts_num = experts_num
        self.softmax = nn.Softmax(1)
        self.softplus = nn.Softplus()
        self.noisy_gating = True
        self.text_or_image = text_or_image
        self.router_list = nn.ParameterList()
        self.w_noise_list = nn.ParameterList()
        for _ in range(self.step):
            self.router_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True))
            self.w_noise_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True))
        self.adaptmlp_list = nn.ModuleList()
        for _ in range(self.experts_num):
            self.adaptmlp_list.append(Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                              init_option='lora',
                                              adapter_scalar=0.1,
                                              adapter_layernorm_option='none',
                                              ))
        self.lora_feature = None  # Temporary save the output of adapter, for method : DMNSP

    def attention(self, x: torch.Tensor, **kwargs):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        x = x.permute(1, 0, 2)
        attn = self.attn(x, attn_mask=self.attn_mask, **kwargs)
        attn = attn.permute(1, 0, 2)
        return attn

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty or single-element Tensor.
        Args:
            x: a `Tensor`.
        Returns:
            a `Scalar`.
        """
        eps = 1e-10
        # Variance is undefined for fewer than two values (e.g. num_experts == 1).
        if x.shape[0] <= 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean()**2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
            gates: a `Tensor` of shape [batch_size, n]
        Returns:
            an integer `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
                normally distributed noise with standard deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n], or None
            noisy_top_values: a `Tensor` of shape [batch, m].
                "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
        Returns:
            a `Tensor` of shape [batch, n].
        """
        # Local import so this method does not depend on a module-level ``Normal``.
        from torch.distributions.normal import Normal
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        # is each value currently in the top k.
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
        """Noisy top-k gating.
        See paper: https://arxiv.org/abs/1701.06538.
        Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - we only add noise at training time.
            noise_epsilon: a float
        Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: currently always None (load computation is disabled)
        """
        clean_logits = x @ w_gate.to(x)
        if self.noisy_gating and train:
            raw_noise_stddev = x @ w_noise.to(x)
            noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits
        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1)
        top_k_logits = top_logits[:, :self.top_k]
        top_k_indices = top_indices[:, :self.top_k]
        top_k_gates = self.softmax(top_k_logits)
        zeros = torch.zeros_like(logits)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)
        # The load term (via _prob_in_top_k / _gates_to_load) is currently unused.
        return gates, None

    def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs):
        """Attention residual followed by an MLP branch that depends on ``experts_num``.

        Args:
            x: [Seq, Batch, Dim] input.
            compute_lora_feat: with a single expert, store the adapter output
                in ``self.lora_feature`` (detached, on CPU).
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))  # [Seq, Batch, Dim]
        if self.experts_num == 0:
            x = x + self.drop_path(self.mlp(self.ln_2(x)))
        elif self.experts_num == 1:
            x_re = x.permute(1, 0, 2)
            adapt_x = self.adaptmlp_list[0](x_re, add_residual=False)
            adapt_x = adapt_x.permute(1, 0, 2)
            x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)
            if compute_lora_feat:
                self.lora_feature = adapt_x.detach().cpu()
        else:
            x_re = x.permute(1, 0, 2)[:, 0, :]
            gates, _ = self.noisy_top_k_gating(x_re, self.is_train, self.router_list[0],
                                               self.w_noise_list[0])  # hardcoded, task_id = 0
            dispatcher = SparseDispatcher(self.experts_num, gates)
            # reshape rather than view: after permute x may be non-contiguous.
            expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).reshape(x.shape[1], -1))
            expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0],
                                                                           x.shape[0], x.shape[2]).to(x), add_residual=False)
                              for i in range(self.experts_num)]
            expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0]
            y = dispatcher.combine(expert_outputs)
            y = y.view(x.shape[1], x.shape[0], x.shape[2])
            x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2))
        return x
class ResidualAttentionBiBlock(nn.Module):
    """Pre-norm transformer block that processes two token streams side by side.

    ``x`` and ``x_proj`` share the same layer norms (``ln_1``, ``ln_2``) and the same
    ``mlp``. The attention module receives both normalised streams together with a
    ``probs`` value and returns an output for each stream plus the updated ``probs``.
    Streams are laid out ``[Seq, Batch, Dim]`` around the attention call.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 # For attn_layer = MultiHeadAttention_LoRA
                 lora_rank: int = 0,
                 lora_bias: bool = False
                 ):
        """
        Args:
            d_model: token width.
            n_head: number of attention heads.
            mlp_ratio: hidden width of the MLP relative to ``d_model``.
            qkv_bias, qk_scale, attn_drop, proj_drop: forwarded to the attention class.
            drop_path: stochastic-depth rate (``nn.Identity`` when 0).
            attn_layer: attention class; ``MultiHeadAttention`` or one of the LoRA variants.
            act_layer: activation for the MLP.
            norm_layer: normalisation class used for ``ln_1`` / ``ln_2``.
            attn_mask: optional mask passed to the attention call.
            text_or_image: stored on the instance; not used by this block itself.
            lora_rank, lora_bias: forwarded to the LoRA attention variants only.

        Raises:
            NotImplementedError: if ``attn_layer`` is not a supported attention class.
        """
        super().__init__()
        if attn_layer == MultiHeadAttention:
            self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        elif attn_layer in (MultiHeadAttention_LoRA,
                            MultiHeadAttention_MaskedLoRA,
                            MultiHeadAttention_MultiMaskedLoRA,
                            MultiHeadAttention_MultiMaskedLoRA3,
                            MultiHeadAttention_MaskedLoRA1):
            # Every LoRA variant is constructed with the same trailing (lora_rank, lora_bias).
            self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
        else:
            # An explicit raise: the former `assert 0` is stripped under `python -O`, which
            # left `self.attn` unset until a confusing AttributeError in `attention()`.
            raise NotImplementedError(f'{attn_layer} not Implemented')
        self.ln_1 = norm_layer(d_model)
        self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer)
        self.ln_2 = norm_layer(d_model)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn_mask = attn_mask
        self.text_or_image = text_or_image

    def attention(self, x: torch.Tensor, x_proj, probs, **kwargs):
        """Run ``self.attn`` on both streams.

        Both streams are permuted ``[Seq, Batch, Dim] -> [Batch, Seq, Dim]`` for the
        attention module and back afterwards. The stored mask is moved to ``x``'s
        dtype/device on every call.

        Returns:
            ``(attn, attn_proj, probs)`` with ``attn`` / ``attn_proj`` sequence-first.
        """
        self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None
        x, x_proj = x.permute(1, 0, 2), x_proj.permute(1, 0, 2)
        attn, attn_proj, probs = self.attn(x, x_proj, probs, attn_mask=self.attn_mask, **kwargs)
        attn, attn_proj = attn.permute(1, 0, 2), attn_proj.permute(1, 0, 2)
        return attn, attn_proj, probs

    def forward(self, x: torch.Tensor, x_proj, probs, **kwargs):
        """Apply the attention and MLP residual updates to both streams.

        Returns:
            ``(x, x_proj, probs)`` -- both updated streams and the ``probs`` value
            returned by the attention module.
        """
        attn, attn_proj, probs = self.attention(self.ln_1(x), self.ln_1(x_proj), probs, **kwargs)
        x = x + self.drop_path(attn)  # [Seq, Batch, Dim]
        x_proj = x_proj + self.drop_path(attn_proj)
        # The same MLP weights are applied to both streams.
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        x_proj = x_proj + self.drop_path(self.mlp(self.ln_2(x_proj)))
        return x, x_proj, probs
# Transformers
class Transformer(nn.Module):
    """Stack of ``layers`` identical blocks applied to sequence-first tokens ``[N, B, C]``.

    ``block_layer``, ``attn_layer``, ``act_layer`` and ``norm_layer`` may be passed either
    as classes or as the name of a global defined in this module. A name that is not
    found falls back to the corresponding default with a printed warning. Remaining
    ``**kwargs`` are forwarded to every block's constructor.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
                 ):
        super().__init__()
        self.width = width
        self.layers = layers
        block_layer = self._resolve_layer(block_layer, ResidualAttentionBlock, 'ResidualAttentionBlock')
        attn_layer = self._resolve_layer(attn_layer, MultiHeadAttention, 'MultiHeadAttention')
        act_layer = self._resolve_layer(act_layer, nn.GELU, 'nn.GELU')
        norm_layer = self._resolve_layer(norm_layer, nn.LayerNorm, 'nn.LayerNorm')
        self.blocks = nn.ModuleList([
            block_layer(
                d_model=width,
                n_head=heads,
                attn_layer=attn_layer,
                act_layer=act_layer,
                norm_layer=norm_layer,
                attn_mask=attn_mask,
                text_or_image=text_or_image,
                **kwargs)
            for _ in range(layers)])

    @staticmethod
    def _resolve_layer(spec, default, default_name):
        """Map a layer given by name to the module-level global of that name.

        Non-string ``spec`` values are returned unchanged. An unknown name prints a
        warning and yields ``default``; ``default_name`` is only used in that message.
        """
        if not isinstance(spec, str):
            return spec
        try:
            return globals()[spec]
        except KeyError:
            print(f'{spec} not found, using default {default_name}')
            return default

    def forward(self, x: torch.Tensor, l2p_prompt=None, l2p_e_prompt_layer_idx=None, **kwargs):
        """Run all blocks, optionally prepending L2P prompts before selected layers.

        Args:
            x: tokens laid out ``[N, B, C]``.
            l2p_prompt: indexable per prompted layer; entry ``k`` is laid out
                ``(B, Prompt_len, C)`` and is used for the k-th layer listed in
                ``l2p_e_prompt_layer_idx``. ``None`` disables prompting.
            l2p_e_prompt_layer_idx: layer indices before which a prompt is prepended
                (``None`` means no layers; replaces the former mutable ``[]`` default).
            **kwargs: forwarded to every block.

        Returns:
            Tokens ``[N', B, C]``. ``N'`` grows by the prompt length for every prompted layer.
        """
        prompt_layers = () if l2p_e_prompt_layer_idx is None else l2p_e_prompt_layer_idx
        prompt_counter = -1
        for i, block in enumerate(self.blocks):
            if l2p_prompt is not None and i in prompt_layers:
                prompt_counter += 1
                # (B, Prompt_len, C) -> (Prompt_len, B, C) to match x's (N, B, C) layout.
                batched_prompt = l2p_prompt[prompt_counter].permute(1, 0, 2)
                x = torch.cat([batched_prompt, x], dim=0)  # prepend along N
            x = block(x, **kwargs)
        return x
class Transformer_Proj(Transformer):
    """Transformer whose blocks carry a second ("projection") token stream.

    Each block is called as ``block(x, x_proj, probs, **kwargs)`` and returns the
    updated ``(x, x_proj, probs)``. Only the projection stream is returned by
    ``forward``. The value threaded through the blocks as ``probs`` is kept on
    ``self.probs`` and is reset to ``[]`` at the start of every forward pass.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
                 ):
        super().__init__(width=width, layers=layers, heads=heads,
                         block_layer=block_layer, attn_layer=attn_layer,
                         act_layer=act_layer, norm_layer=norm_layer,
                         attn_mask=attn_mask, text_or_image=text_or_image,
                         **kwargs)
        self.probs = []

    def forward(self, x: torch.Tensor, **kwargs):
        """Return the projection stream after all blocks; it starts as a copy of ``x``."""
        proj_stream = x.clone()
        self.probs = []
        for blk in self.blocks:
            x, proj_stream, self.probs = blk(x, proj_stream, self.probs, **kwargs)
        return proj_stream
class Transformer_CL_LoRA(Transformer):
    """Transformer used by CL-LoRA: each block receives its own adapter and, from
    block 6 onwards, one column of a shared ``block_weight`` matrix."""

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
                 ):
        super().__init__(width=width, layers=layers, heads=heads,
                         block_layer=block_layer, attn_layer=attn_layer,
                         act_layer=act_layer, norm_layer=norm_layer,
                         attn_mask=attn_mask, text_or_image=text_or_image,
                         **kwargs)

    def forward(self, x, adapt, prompt, rank_prompt, block_weight, **kwargs):
        """Run every block ``d`` with adapter ``adapt[d]``.

        Blocks 0-5 get ``block_weight=None``; block ``d >= 6`` gets column ``d - 6`` of
        ``block_weight``. ``prompt`` and ``rank_prompt`` are passed to every block.
        """
        n_general = 6
        for depth, block in enumerate(self.blocks):
            x = block(
                x,
                adapt=adapt[depth],
                prompt=prompt,
                rank_prompt=rank_prompt,
                block_weight=block_weight[:, depth - n_general] if depth >= n_general else None,
                **kwargs
            )
        return x
# ViT from CLIP
class VisualTransformer(nn.Module):
    """CLIP-style ViT image encoder.

    A strided convolution cuts the image into ``patch_size`` x ``patch_size`` patches.
    A learned class embedding is prepended and learned positional embeddings are
    added. The sequence then runs through ``self.transformer`` in sequence-first
    layout, and the normalised class-token output is projected to ``output_dim``.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int = 3,
                 width: int = 768,
                 depth: int = 12,
                 heads: int = 8,
                 output_dim: int = 512,
                 text_or_image: str = None,
                 **kwargs
                 ):
        super().__init__()
        # Hyper-parameters kept for introspection.
        self.img_size, self.patch_size, self.in_chans = img_size, patch_size, in_chans
        self.width, self.depth, self.heads = width, depth, heads
        self.output_dim = output_dim

        # Module / parameter creation order is deliberate: it fixes both the random
        # initialisation sequence and the order returned by .parameters().
        self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=width,
                               kernel_size=patch_size, stride=patch_size, bias=False)
        init_std = width ** -0.5
        num_tokens = (img_size // patch_size) ** 2 + 1  # patches + class token
        self.class_embedding = nn.Parameter(init_std * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_std * torch.randn(num_tokens, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, depth, heads, text_or_image=text_or_image, **kwargs)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_std * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, **kwargs):
        """Encode images into ``output_dim`` features (or ``width`` if ``proj`` is None)."""
        patches = self.conv1(x)  # [B, width, H', W']
        n_img, n_ch = patches.shape[0], patches.shape[1]
        patches = patches.reshape(n_img, n_ch, -1).permute(0, 2, 1)  # [B, H'*W', width]
        dtype = patches.dtype
        cls_tok = self.class_embedding.to(dtype) + torch.zeros(
            n_img, 1, patches.shape[-1], dtype=dtype, device=patches.device)
        tokens = torch.cat([cls_tok, patches], dim=1)  # [B, H'*W' + 1, width]
        tokens = self.ln_pre(tokens + self.positional_embedding.to(dtype))
        # The transformer works sequence-first: [B, L, D] -> [L, B, D] -> back again.
        tokens = self.transformer(tokens.permute(1, 0, 2), **kwargs).permute(1, 0, 2)
        pooled = self.ln_post(tokens[:, 0, :])
        if self.proj is None:
            return pooled
        return pooled @ self.proj
# Standard ViT
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929

    Tokens are kept batch-first (``[B, N, C]``) outside the transformer and are
    permuted to sequence-first (``[N, B, C]``) for the blocks.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_layer=MultiHeadAttention,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 ckpt_layer=0,
                 transformer_layer=Transformer,
                 **kwargs):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            transformer_layer: ``'Transformer_Proj'`` / ``'Transformer_CL_LoRA'`` (or the
                class itself) selects that transformer; anything else builds ``Transformer``.

        Note:
            ``num_classes``, ``mlp_ratio``, ``qkv_bias``, ``qk_scale``,
            ``representation_size``, ``attn_drop_rate``, ``drop_path_rate`` and
            ``ckpt_layer`` are accepted but not used by this constructor; block options
            travel through ``**kwargs``.
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_heads = num_heads
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Accept the class objects as well as their names (passing the class used to
        # silently fall through to the plain Transformer).
        if transformer_layer in ('Transformer_Proj', Transformer_Proj):
            transformer_cls = Transformer_Proj
        elif transformer_layer in ('Transformer_CL_LoRA', Transformer_CL_LoRA):
            transformer_cls = Transformer_CL_LoRA
        else:
            transformer_cls = Transformer
        self.transformer = transformer_cls(embed_dim, depth, num_heads, text_or_image='image',
                                           attn_layer=attn_layer, norm_layer=norm_layer, **kwargs)
        self.norm = partial(nn.LayerNorm, eps=1e-6)(embed_dim)
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _add_cls_and_pos(self, x):
        """Prepend the cls token to patch tokens ``[B, P, C]``, add positions, apply dropout."""
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        return self.pos_drop(x)

    def forward(self, x, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs):
        """Encode a batch of images.

        Args:
            x: image batch passed to ``self.patch_embed``.
            register_blk: index of the block called with ``register_hook=True``
                (default path with a prompt module only).
            prompt: prompt module or ``None``. With ``prompt_flag == 'l2p'`` it is called
                as ``prompt(x, cls_features=...)``. Otherwise ``prompt.forward(q, i, x, ...)``
                is called before every block ``i``.
            prompt_flag: ``'l2p'`` selects the L2P path; anything else the default path.
            q: query forwarded to ``prompt.forward`` (default path).
            train: default path only -- whether the prompt loss is accumulated.
            task_id: forwarded to ``prompt.forward``.
            cls_features: forwarded to the L2P prompt module.
            **kwargs: forwarded to the transformer / blocks.

        Returns:
            L2P path: ``(mean of the prompt tokens, reduce_sim)`` if ``prompt`` is truthy,
            otherwise the normalised cls token ``[B, C]``.
            Default path: ``(normalised tokens [B, N, C], prompt_loss of shape [1])``.
        """
        if prompt_flag == 'l2p':
            return self._forward_l2p(x, prompt, cls_features, **kwargs)
        return self._forward_default(x, register_blk, prompt, q, train, task_id, **kwargs)

    def _forward_l2p(self, x, prompt, cls_features, **kwargs):
        """L2P path: prompts are prepended by the transformer before layer 0."""
        x = self.patch_embed(x)
        batched_prompt = None
        e_prompt_layer_idx = []
        if prompt:
            e_prompt_layer_idx = [0]
            total_prompt_len = prompt.length * prompt.top_k * len(e_prompt_layer_idx)
            # The prompt pool is queried with the patch tokens, before the cls token is added.
            batched_prompt, reduce_sim = prompt(x, cls_features=cls_features)
        x = self._add_cls_and_pos(x)
        x = x.permute(1, 0, 2)  # (B, N ,C) -> (N, B ,C)
        x = self.transformer(
            x,
            l2p_prompt=batched_prompt,
            l2p_e_prompt_layer_idx=e_prompt_layer_idx,
            **kwargs
        )
        x = x.permute(1, 0, 2)  # (N, B ,C) -> (B, N ,C)
        x = self.norm(x)
        if prompt:
            # Pool the leading prompt tokens.
            x = x[:, :total_prompt_len]
            x = x.mean(dim=1)
            return x, reduce_sim
        return x[:, 0]

    def _forward_default(self, x, register_blk, prompt, q, train, task_id, **kwargs):
        """Default path: without a prompt module run the transformer directly,
        otherwise query ``prompt.forward`` before every block (CODA-style)."""
        x = self._add_cls_and_pos(self.patch_embed(x))
        # Create directly on the target device and accumulate out of place. The former
        # `torch.zeros(..., requires_grad=True).to(device)` returned the grad-requiring leaf
        # itself on CPU, so `prompt_loss += loss` raised the leaf in-place error.
        prompt_loss = torch.zeros((1,), device=x.device, requires_grad=True)
        if prompt is None:
            x = x.permute(1, 0, 2)
            x = self.transformer(x, **kwargs)
            x = x.permute(1, 0, 2)
        else:
            # TODO: clean, move everything to transformer
            for i, blk in enumerate(self.transformer.blocks):
                if train:
                    p_list, loss, x = prompt.forward(q, i, x, train=True, task_id=task_id)
                    prompt_loss = prompt_loss + loss
                else:
                    p_list, _, x = prompt.forward(q, i, x, train=False, task_id=task_id)
                # prompt.forward works on [B, N, C]; the blocks take [N, B, C].
                x = x.permute(1, 0, 2)
                x = blk(x, register_hook=register_blk == i, prompt=p_list, **kwargs)
                x = x.permute(1, 0, 2)
        x = self.norm(x)
        return x, prompt_loss

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)
class VisionTransformer_CL_LoRA(VisionTransformer):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929

    CL-LoRA variant. For every transformer block, ``cur_adapter`` holds an
    ``nn.ModuleList`` with one entry per element of ``msa`` (a low-rank adapter where
    the entry is 1, ``nn.Identity`` where it is 0).
    Blocks in ``general_pos`` keep one adapter set across tasks. When a new task
    starts (``add_adapter_to_list``), only the ``specfic_pos`` entries are replaced, and
    the previous ones are archived in ``adapter_list`` together with their ``block_weight``.
    """

    # Blocks with an index below this use the current general adapters and no block
    # weight; later blocks use task-specific adapters and a column of `block_weight`.
    _NUM_GENERAL_BLOCKS = 6

    class Adapter_lora(nn.Module):
        """Low-rank adapter computing ``lora_A(lora_B(x))``.

        ``lora_B`` (``n_embd -> down_size``) gets orthonormal rows: the transposed
        reduced-QR ``q`` factor of a uniform random matrix. ``lora_A``
        (``down_size -> n_embd``) is zero-initialised (``init_option='lora'``), so the
        adapter initially outputs zeros. ``dropout``, ``adapter_scalar`` and
        ``adapter_layernorm_option`` are accepted but unused.

        Raises:
            NotImplementedError: for any ``init_option`` other than ``'lora'``.
        """

        def __init__(self,
                     config=None,
                     d_model=None,
                     bottleneck=None,
                     dropout=0.0,
                     init_option="bert",
                     adapter_scalar="1.0",
                     adapter_layernorm_option="in"):
            super().__init__()
            self.n_embd = config.d_model if d_model is None else d_model
            self.down_size = config.attn_bn if bottleneck is None else bottleneck
            self.lora_A = nn.Linear(self.down_size, self.n_embd, bias=False)
            self.lora_B = nn.Linear(self.n_embd, self.down_size, bias=False)
            random_matrix = torch.rand(self.n_embd, self.down_size)
            q, _ = torch.linalg.qr(random_matrix)
            with torch.no_grad():
                self.lora_B.weight.copy_(q.T)
            if init_option == "bert":
                raise NotImplementedError
            elif init_option == "lora":
                with torch.no_grad():
                    nn.init.zeros_(self.lora_A.weight)
            else:
                raise NotImplementedError

        def forward(self, x):
            inter_x = self.lora_B(x)
            out = self.lora_A(inter_x)
            return out

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_layer=MultiHeadAttention,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 ckpt_layer=0,
                 transformer_layer=Transformer,
                 **kwargs):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            **kwargs: must contain ``lora_rank`` (adapter bottleneck width); also
                forwarded to the parent constructor.
        """
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            attn_layer=attn_layer,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            representation_size=representation_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            ckpt_layer=ckpt_layer,
            transformer_layer=transformer_layer,
            **kwargs
        )
        cfg_dict = {
            'use_distillation': True,
            'use_block_weight': True,
            'msa_adapt': True,
            'msa': [1, 0, 1],
            'specfic_pos': [6, 7, 8, 9, 10, 11],
            'general_pos': [0, 1, 2, 3, 4, 5],
            'ffn_adapt': True,
            'ffn_option': 'parallel',
            'ffn_adapter_layernorm_option': 'none',
            'ffn_adapter_init_option': 'lora',
            'ffn_adapter_scalar': '0.1',
            'ffn_num': kwargs['lora_rank'],
            # Adapters must match the real token width (was hard-coded to 768).
            'd_model': embed_dim,
            'vpt_on': False,
            'vpt_num': 0,
            # Fall back to CPU instead of failing in `.to('cuda:0')` on CUDA-less hosts.
            '_device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
        }
        from types import SimpleNamespace
        self.tuning_config = SimpleNamespace(**cfg_dict)
        self.config = self.tuning_config
        self._device = self.tuning_config._device
        self.msa_adapt = self.tuning_config.msa_adapt
        self.use_distillation = self.tuning_config.use_distillation
        self.use_block_weight = self.tuning_config.use_block_weight
        self.general_pos = self.tuning_config.general_pos
        self.specfic_pos = self.tuning_config.specfic_pos
        self.adapt_pos = sorted(self.general_pos + self.specfic_pos)
        if self.msa_adapt:
            self.msa = self.tuning_config.msa
        if self.use_distillation:
            self.old_adapter_list = nn.ModuleList()
        if self.use_block_weight:
            self.block_weight_list = []
            self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos)))
            nn.init.uniform_(self.block_weight, .5, 1.5)
        # NOTE(review): adapter_list / block_weight_list are plain Python lists, so their
        # contents are neither moved by `.to()` nor saved in `state_dict()`.
        self.adapter_list = []
        self.adapter_pos_list = []
        self.cur_adapter = nn.ModuleList()
        self.get_new_adapter_initial_msa()

    def forward(self, x, test = False, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs):
        """Return ``(features, None)``.

        Training (``test=False``): cls token of the current adapters, ``[bs, 768]``.
        Test: cls tokens of every archived task plus the current one, concatenated
        along the feature dimension, ``[bs, num_tasks * 768]``.
        The prompt-related arguments are accepted for interface compatibility only.
        """
        if not test:
            output = self.forward_train(x)
            return output[:, 0], None  # [bs, 768]
        features = self.forward_test(x)
        # One concat instead of growing from an empty 1-D `torch.Tensor()`.
        output = torch.cat([feat[:, 0, :] for feat in features], dim=1)
        return output, None  # [bs, task_id * 768]

    def _embed_tokens(self, x):
        """Patch-embed images, prepend the cls token, add positions, apply dropout -> [B, N, C]."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        return self.pos_drop(x)

    def _run_blocks(self, x, specific_adapters, block_weight):
        """Run every block on batch-first tokens ``x`` and return batch-first output.

        Blocks below ``_NUM_GENERAL_BLOCKS`` use ``self.cur_adapter[idx]`` and no block
        weight. Block ``idx`` above that uses ``specific_adapters[idx - n]`` and column
        ``idx - n`` of ``block_weight``.
        """
        n_general = self._NUM_GENERAL_BLOCKS
        x = x.permute(1, 0, 2)  # blocks take [N, B, C]
        for idx, blk in enumerate(self.transformer.blocks):
            if idx >= n_general:
                x = blk(x, adapt=specific_adapters[idx - n_general], prompt=None, rank_prompt=None,
                        block_weight=block_weight[:, idx - n_general])
            else:
                x = blk(x, adapt=self.cur_adapter[idx], prompt=None, rank_prompt=None, block_weight=None)
        return x.permute(1, 0, 2)

    def forward_train(self, x):
        """Forward with the current adapters and block weight; returns normalised tokens [B, N, C]."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)
        x = x.permute(1, 0, 2)
        x = self.transformer(
            x,
            adapt=self.cur_adapter,
            prompt=None,
            rank_prompt=None,
            block_weight=self.block_weight)
        x = x.permute(1, 0, 2)
        x = self.norm(x)
        return x

    def forward_test(self, x, use_init_ptm=False):
        """Return normalised token features [B, N, C] for every archived task, then the current one.

        ``use_init_ptm`` is accepted but unused.
        """
        x_init = self._embed_tokens(x)
        assert self.config.ffn_adapt
        assert self.adapt_pos == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        assert self.general_pos == [0, 1, 2, 3, 4, 5]
        assert self.use_block_weight
        features = []
        # len(self.adapter_list) == cur_task_id
        # `clone` instead of `copy.deepcopy`: deepcopy refuses non-leaf tensors that track grads.
        for task_adapters, task_weight in zip(self.adapter_list, self.block_weight_list):
            features.append(self.norm(self._run_blocks(x_init.clone(), task_adapters, task_weight)))
        current_specific = self.cur_adapter[self._NUM_GENERAL_BLOCKS:]
        features.append(self.norm(self._run_blocks(x_init.clone(), current_specific, self.block_weight)))
        return features

    def forward_proto(self, x, adapt_index):
        """Normalised cls feature using archived task ``adapt_index`` (current adapters if it is not archived yet)."""
        assert adapt_index > -1
        assert self.config.ffn_adapt
        assert self.use_block_weight
        x = self._embed_tokens(x)
        if adapt_index < len(self.adapter_list):
            x = self._run_blocks(x, self.adapter_list[adapt_index], self.block_weight_list[adapt_index])
        else:
            x = self._run_blocks(x, self.cur_adapter[self._NUM_GENERAL_BLOCKS:], self.block_weight)
        x = self.norm(x)
        return x[:, 0, :]

    def forward_general_cls(self, x, t_idx):
        """cls features after the general blocks only, with the current adapters and
        with the adapters snapshotted in ``old_adapter_list[t_idx - 1]``.

        Returns:
            ``(output_new, output_teacher)``, each ``[B, C]``.
        """
        x = self._embed_tokens(x)
        # `clone` instead of `copy.deepcopy`: deepcopy refuses non-leaf tensors that track grads.
        x_teacher = x.clone()
        # The blocks take [N, B, C] (as in every other path); this path previously fed
        # [B, N, C], mixing samples across the batch inside attention.
        x = x.permute(1, 0, 2)
        x_teacher = x_teacher.permute(1, 0, 2)
        for j in range(self._NUM_GENERAL_BLOCKS):  # [0, ..., 5]
            x = self.transformer.blocks[j](x, adapt=self.cur_adapter[j])
            x_teacher = self.transformer.blocks[j](x_teacher, adapt=self.old_adapter_list[t_idx - 1][j])
        x = self.norm(x.permute(1, 0, 2))
        output_new = x[:, 0, :]
        x_teacher = self.norm(x_teacher.permute(1, 0, 2))
        output_teacher = x_teacher[:, 0, :]
        return output_new, output_teacher

    def get_new_adapter_initial_msa(self):
        """Append one adapter ModuleList per position in ``adapt_pos`` to ``cur_adapter``."""
        config = self.config
        if config.ffn_adapt:
            for i in range(len(self.adapt_pos)):
                temp_adapter = nn.ModuleList()
                for j in self.msa:
                    if j == 1:
                        adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num,
                                                                         init_option=config.ffn_adapter_init_option,
                                                                         adapter_scalar=config.ffn_adapter_scalar,
                                                                         adapter_layernorm_option=config.ffn_adapter_layernorm_option,
                                                                         ).to(self._device)
                    else:
                        adapter = nn.Identity()
                    temp_adapter.append(adapter)
                self.cur_adapter.append(temp_adapter)
            self.cur_adapter.requires_grad_(True)
        else:
            print("====Not use adapter===")

    def add_adapter_to_list(self):
        """Archive the current task-specific adapters / block weight and start fresh ones."""
        import copy
        temp_adapter = []
        for i in range(len(self.specfic_pos)):
            temp_pos = self.adapt_pos.index(self.specfic_pos[i])
            # NOTE: `requires_grad_(False)` also freezes the live cur_adapter entry, not
            # just the copy; get_new_adapter_msa re-enables grads afterwards.
            temp_adapter.append(copy.deepcopy(self.cur_adapter[temp_pos].requires_grad_(False)))
        self.adapter_list.append(temp_adapter)
        if self.use_block_weight:
            self.block_weight_old = copy.deepcopy(self.block_weight)
            self.block_weight_list.append(self.block_weight_old.requires_grad_(False))
            # NOTE(review): the fresh parameter is created on CPU; presumably the caller
            # moves the model to its device again before the next task -- confirm.
            self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos)))
            nn.init.uniform_(self.block_weight, .5, 1.5)
        self.adapter_pos_list.append(self.adapt_pos)
        if self.use_distillation:
            self.old_adapter_list.append(copy.deepcopy(self.cur_adapter).requires_grad_(False))
        if self.msa_adapt:
            self.get_new_adapter_msa()

    def get_new_adapter_msa(self):
        """Replace the task-specific entries of ``cur_adapter`` with fresh adapters and
        freeze ``lora_B`` of the general-position adapters."""
        config = self.config
        if config.ffn_adapt:
            for i in range(len(self.specfic_pos)):
                pos = self.adapt_pos.index(self.specfic_pos[i])
                temp_adapter = nn.ModuleList()
                for j in self.msa:
                    if j == 1:
                        adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num,
                                                                         init_option=config.ffn_adapter_init_option,
                                                                         adapter_scalar=config.ffn_adapter_scalar,
                                                                         adapter_layernorm_option=config.ffn_adapter_layernorm_option,
                                                                         ).to(self._device)
                        adapter.requires_grad_(True)
                    else:
                        adapter = nn.Identity()
                    temp_adapter.append(adapter)
                self.cur_adapter[pos] = temp_adapter
            if len(self.specfic_pos) < 12:
                self.cur_adapter.requires_grad_(True)
                for i in self.adapt_pos:
                    if i in self.general_pos:
                        pos = self.adapt_pos.index(i)
                        for j in range(len(self.msa)):
                            if self.msa[j] == 1:
                                self.cur_adapter[pos][j].lora_B.requires_grad_(False)
        else:
            print("====Not use adapter===")

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation

    Copies, in place and without autograd tracking: the optional hybrid CNN stem /
    stages, the patch embedding, the cls token, the positional embedding (resized if
    the shape differs), the final norm, and every transformer block.

    Args:
        model: model to fill.
        checkpoint_path: path to the ``.npz`` archive (read with ``np.load``).
        prefix: key prefix inside the archive; auto-set to ``'opt/target/'`` when
            empty and ``'opt/target/embedding/kernel'`` is present.

    NOTE(review): this expects timm-style attributes -- ``model.blocks`` and blocks
    with ``norm1``/``norm2``/``attn.qkv``/``mlp.fc1``. ``VisionTransformer`` in this
    file keeps its blocks under ``model.transformer.blocks``, and the blocks visible
    here use ``ln_1``/``ln_2``, so confirm this loader matches the model it is used on.
    ``adapt_input_conv`` and ``resize_pos_embed`` are not imported in the visible
    file header -- confirm they are defined elsewhere.
    """
    import numpy as np

    def _n2p(w, t=True):
        # Convert a numpy array to a torch tensor. A 4-D array whose first three dims
        # are 1 is flattened. With t=True, Flax layouts are transposed to PyTorch
        # layouts (4-D: HWIO -> OIHW, 3-D: [2, 0, 1], 2-D: plain transpose).
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # Plain ViT: adapt the embedding kernel to the model's number of input channels.
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        # Separate query/key/value kernels and biases are concatenated into one fused qkv.
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))