| | ''' |
| | Code Reference: |
| | |
| | * https://github.com/jadore801120/attention-is-all-you-need-pytorch/ |
| | * https://github.com/GT-RIPL/CODA-Prompt |
| | * https://github.com/openai/CLIP |
| | ''' |
| |
|
| | import os |
| | import math |
| | import torch |
| | import numpy as np |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from functools import partial |
| | from collections import Counter |
| | from timm.models.vision_transformer import PatchEmbed |
| | from timm.models.layers import trunc_normal_, DropPath |
| | from scipy.special import softmax |
| |
|
| | from .petl.adapter import Adapter, MaskedAdapter |
| | from .petl.proj import Proj |
| | from .prompt import L2P |
| |
|
| | |
class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.

    The class builds per-expert input minibatches from a batch and merges
    the per-expert outputs back into a single tensor.

    It is initialised with a ``gates`` tensor of shape
    ``[batch_size, num_experts]``. Batch element ``b`` is sent to expert
    ``e`` iff ``gates[b, e] != 0``; the gate value is also the weight used
    when the expert outputs are summed in :meth:`combine`.

    Inputs and outputs are two-dimensional ``[batch, depth]``; the caller is
    responsible for collapsing extra dimensions beforehand and restoring
    them afterwards.

    Example use:
        dispatcher = SparseDispatcher(num_experts, gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
        outputs = dispatcher.combine(expert_outputs)

    which sets, for every batch element ``b``:
        output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
            num_experts: number of experts (columns of ``gates``).
            gates: ``[batch_size, num_experts]`` tensor of gate weights.
        """
        self._gates = gates
        self._num_experts = num_experts

        # (batch, expert) coordinates of every active gate, computed once.
        active = torch.nonzero(gates)

        # Column-wise sort: column 1 of the values is the expert id of each
        # active gate in ascending order, column 1 of the permutation tells
        # which row of `active` that entry came from.
        sorted_experts, index_sorted_experts = active.sort(0)
        _, self._expert_index = sorted_experts.split(1, dim=1)

        # Batch row belonging to each entry, grouped by expert.
        self._batch_index = active[index_sorted_experts[:, 1], 0]

        # Number of samples each expert receives. This must use the same
        # "!= 0" criterion as torch.nonzero above; counting with "> 0" makes
        # the split sizes disagree with the index length as soon as a gate
        # is negative.
        self._part_sizes = (gates != 0).sum(0).tolist()

        # Gate value for every (batch, expert) entry, in dispatch order.
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input tensor for each expert.

        The tensor for expert ``i`` contains the slices of ``inp`` for the
        batch elements ``b`` where ``gates[b, i] != 0``.

        Args:
            inp: tensor of shape ``[batch_size, <extra_input_dims>]``.

        Returns:
            A tuple of ``num_experts`` tensors with shapes
            ``[expert_batch_size_i, <extra_input_dims>]``.
        """
        # NOTE(review): squeeze(1) is kept from the original; it also drops
        # dimension 1 of `inp` whenever that dimension has size 1.
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert outputs, weighted by the gates.

        Args:
            expert_out: list of ``num_experts`` tensors, each of shape
                ``[expert_batch_size_i, <extra_output_dims>]``.
            multiply_by_gates: when False the gate values are ignored and
                the expert outputs are summed unweighted.

        Returns:
            A float tensor of shape ``[batch_size, <extra_output_dims>]``.
        """
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)

        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), device=stitched.device)

        # Scatter-add every expert row back onto the batch row it came from.
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert tensors.

        Returns:
            A tuple of ``num_experts`` tensors, one per expert, each holding
            the gate values of the examples routed to that expert (split
            from a ``[num_active, 1]`` tensor along dimension 0).
        """
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
| |
|
| | |
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 and returns the input dtype.

    Useful for fp16 inputs: the input is cast up to float32 for the
    normalisation and the result is cast back afterwards.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
| |
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
| |
|
| | |
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional prefix key/value prompts.

    Args:
        dim: embedding width of the input tokens.
        num_heads: number of attention heads; ``dim`` is split evenly.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: explicit logit scale; when falsy, ``head_dim ** -0.5``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Dropout modules are only created when the probability is positive.
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0. else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()

        # Filled in by forward(register_hook=True) and its backward hook.
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention weights (backward hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the last stored attention gradient, or None."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention weights of the last hooked forward pass."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the last stored attention weights, or None."""
        return self.attention_map

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            attn_mask: optional additive mask; it is unsqueezed at dim 0 and
                added in place to the scaled attention logits.
            register_hook: when True, keep the attention weights and register
                a hook that stores their gradient.
            prompt: optional ``(pk, pv)`` pair; each is reshaped to
                ``(B, -1, num_heads, C // num_heads)`` and prepended to the
                keys / values along the token axis.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]

        if prompt is not None:
            prefix_k, prefix_v = prompt
            prefix_k = prefix_k.reshape(batch, -1, self.num_heads, head_dim).permute(0, 2, 1, 3)
            prefix_v = prefix_v.reshape(batch, -1, self.num_heads, head_dim).permute(0, 2, 1, 3)
            k = torch.cat((prefix_k, k), dim=2)
            v = torch.cat((prefix_v, v), dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = self.attn_drop(attn.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        mixed = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
| |
|
class MultiHeadAttention_LoRA(MultiHeadAttention):
    """Attention module with LoRA applied to the key and value projections.

    The low-rank update ``lora_B @ lora_A`` is added to the k and v slices
    of the fused qkv weight while ``apply_lora`` is True.

    Args:
        dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: forwarded
            to ``MultiHeadAttention``.
        lora_rank: inner dimension of the low-rank factors.
        lora_bias: whether the LoRA linear layers carry a bias.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)

        self.lora_rank = lora_rank

        self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.apply_lora = False

        # Running average of x^T x over all tokens seen with
        # get_input_matrix=True. Plain CPU tensor, not a registered buffer.
        self.cur_matrix = torch.zeros(self.dim, self.dim)
        self.n_cur_matrix = 0

    def init_param(self):
        """Re-initialise the LoRA factors and switch LoRA on.

        A is Kaiming-uniform, B is zero, so the update starts at zero.
        Only the weights are touched; LoRA biases (if any) are left as is.
        """
        nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k.weight)
        nn.init.zeros_(self.lora_B_v.weight)

        self.apply_lora = True

    def merge_weight(self):
        """Fold the current LoRA update into ``qkv.weight`` and switch LoRA off."""
        # no_grad: the merged weight is written through .data, so there is
        # no reason to record an autograd graph for the products.
        with torch.no_grad():
            q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
            self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0)
        self.apply_lora = False

    def reset_input_matrix(self):
        """Clear the accumulated input statistics."""
        self.cur_matrix.zero_()
        self.n_cur_matrix = 0

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False):
        """Apply attention with the (optional) LoRA update on k and v.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            register_hook: keep attention weights and hook their gradient.
            prompt: accepted for signature compatibility; not used here.
            get_input_matrix: when True, update the running average
                ``cur_matrix`` with this batch before attending.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        if get_input_matrix:
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0]*x.shape[1]

        B, N, C = x.shape

        weight = self.qkv.weight
        if self.apply_lora:
            q_weight, k_weight, v_weight = weight.chunk(3, dim=0)
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
            weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

        # The bias is passed detached (.data) as before; it is None when the
        # module was built with qkv_bias=False, which used to raise here.
        bias = None if self.qkv.bias is None else self.qkv.bias.data

        qkv = F.linear(x, weight, bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
| |
|
class MultiHeadAttention_SDLoRA(MultiHeadAttention):
    """Attention with a growing list of LoRA pairs on the q and v projections.

    Each call to :meth:`init_param` appends one new (A, B) pair per
    projection. In :meth:`forward` the newest pair is scaled by
    ``mag_lora[-1]``; every earlier pair is divided by the product of its
    factor norms and scaled by ``mag_lora[i] + assimilated_mag_lora_*[i]``.

    NOTE(review): ``self.mag_lora`` is read but never assigned in this
    class; it presumably has to be attached by the owner before
    ``init_param`` / ``forward`` are called -- confirm against the caller.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)

        self.lora_rank = lora_rank
        self.lora_bias = lora_bias

        self.lora_A_q_list = nn.ModuleList([])
        self.lora_B_q_list = nn.ModuleList([])
        self.lora_A_v_list = nn.ModuleList([])
        self.lora_B_v_list = nn.ModuleList([])

        # Plain Python lists of one-element tensors (not parameters/buffers).
        self.assimilated_mag_lora_q = []
        self.assimilated_mag_lora_v = []

    def init_param(self):
        """Append a fresh LoRA pair for q and v (A Kaiming-uniform, B zero)."""
        def new_down():
            return nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)

        def new_up():
            return nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)

        self.lora_A_q_list.append(new_down())
        self.lora_B_q_list.append(new_up())
        self.lora_A_v_list.append(new_down())
        self.lora_B_v_list.append(new_up())

        nn.init.kaiming_uniform_(self.lora_A_q_list[-1].weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_q_list[-1].weight)
        nn.init.zeros_(self.lora_B_v_list[-1].weight)

        device = self.qkv.weight.device
        self.assimilated_mag_lora_q.append(torch.Tensor([0.0]).to(device))
        self.assimilated_mag_lora_v.append(torch.Tensor([0.0]).to(device))

        # One magnitude and one assimilated magnitude per LoRA pair.
        assert len(self.lora_A_q_list) == len(self.mag_lora)
        assert len(self.mag_lora) == len(self.assimilated_mag_lora_q)

    def _lora_delta(self, x, down_list, up_list, assimilated):
        """Sum of the newest LoRA output and the norm-scaled earlier ones."""
        delta = self.mag_lora[-1] * up_list[-1](down_list[-1](x))

        for i in range(len(self.lora_A_q_list) - 1):
            norm_up = torch.norm(up_list[i].weight)
            norm_down = torch.norm(down_list[i].weight)

            # Pairs with an all-zero factor are skipped (avoids dividing by 0).
            if norm_up != 0 and norm_down != 0:
                delta += (self.mag_lora[i] + assimilated[i]) * up_list[i](down_list[i](x)) / (norm_up * norm_down)

        return delta

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None):
        """Apply attention with the LoRA updates added to q and v.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            register_hook: keep attention weights and hook their gradient.
            prompt: accepted for signature compatibility; not used here.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape

        qq = self._lora_delta(x, self.lora_A_q_list, self.lora_B_q_list, self.assimilated_mag_lora_q)
        vv = self._lora_delta(x, self.lora_A_v_list, self.lora_B_v_list, self.assimilated_mag_lora_v)

        # Fused projection: first `dim` channels are q, last `dim` are v.
        qkv = self.qkv(x)
        qkv[:, :, : self.dim] += qq
        qkv[:, :, -self.dim :] += vv

        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = self.attn_drop(attn.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
| |
|
class MultiHeadAttention_LoRA_Sub(MultiHeadAttention):
    """Attention module with LoRA on k and v that accumulates past updates.

    Finished LoRA updates are summed into the buffers ``prev_k_weight`` /
    ``prev_v_weight`` by :meth:`save_weight` instead of being merged into
    ``qkv.weight``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)

        self.lora_rank = lora_rank

        self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)
        self.apply_lora = False

        # Running average of x^T x over all tokens seen with
        # get_input_matrix=True. Plain CPU tensor, not a registered buffer.
        self.cur_matrix = torch.zeros(self.dim, self.dim)
        self.n_cur_matrix = 0

        # Sum of all LoRA updates stored so far by save_weight().
        self.register_buffer("prev_k_weight", torch.zeros(self.dim, self.dim))
        self.register_buffer("prev_v_weight", torch.zeros(self.dim, self.dim))

    def init_param(self):
        """Re-initialise the LoRA factors (A Kaiming-uniform, B zero) and switch LoRA on."""
        nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k.weight)
        nn.init.zeros_(self.lora_B_v.weight)

        self.apply_lora = True

    def save_weight(self):
        """Add the current LoRA update to the ``prev_*`` buffers and switch LoRA off."""
        # no_grad: an in-place add of a grad-tracked product would attach the
        # buffers to this step's autograd graph, and every later forward
        # would then try to backpropagate through that stale graph.
        with torch.no_grad():
            self.prev_k_weight += self.lora_B_k.weight @ self.lora_A_k.weight
            self.prev_v_weight += self.lora_B_v.weight @ self.lora_A_v.weight
        self.apply_lora = False

    def reset_input_matrix(self):
        """Clear the accumulated input statistics."""
        self.cur_matrix.zero_()
        self.n_cur_matrix = 0

    def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False):
        """Apply attention with stored and (optionally) current LoRA updates.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            register_hook: keep attention weights and hook their gradient.
            prompt: accepted for signature compatibility; not used here.
            get_input_matrix: when True, update ``cur_matrix`` with this
                batch and attend with the stored updates *subtracted*.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape

        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)

        if get_input_matrix:
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0]*x.shape[1]

            # NOTE(review): this branch subtracts the stored updates while
            # the two branches below add them; kept as in the original --
            # confirm the sign is intended.
            k_weight = k_weight - self.prev_k_weight
            v_weight = v_weight - self.prev_v_weight
        elif self.apply_lora:
            k_weight = k_weight + self.prev_k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.prev_v_weight + self.lora_B_v.weight @ self.lora_A_v.weight
        else:
            k_weight = k_weight + self.prev_k_weight
            v_weight = v_weight + self.prev_v_weight

        # The bias is passed detached (.data) as before; it is None when the
        # module was built with qkv_bias=False, which used to raise here.
        bias = None if self.qkv.bias is None else self.qkv.bias.data

        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
| |
|
class MultiHeadAttention_CL_LoRA(MultiHeadAttention_LoRA):
    """Attention module with LoRA applied to q and v.

    The key LoRA pair created by the parent is deleted and replaced by a
    query pair. In :meth:`forward` the update itself comes from the
    ``adapt`` argument rather than from this module's own LoRA layers.

    NOTE(review): the inherited ``merge_weight`` and ``forward`` logic of the
    parent refer to ``lora_A_k`` / ``lora_B_k``, which no longer exist on
    this class; ``merge_weight`` would raise if called.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)

        del self.lora_A_k
        del self.lora_B_k
        self.lora_A_q = nn.Linear(self.dim, self.lora_rank, bias=lora_bias)
        self.lora_B_q = nn.Linear(self.lora_rank, self.dim, bias=lora_bias)

    def init_param(self):
        """Initialise A with orthonormal rows (QR of a random matrix), B with zeros."""
        q1, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank))
        q2, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank))
        with torch.no_grad():
            # q1 / q2 have shape (dim, rank); the Linear weight is (rank, dim).
            self.lora_A_q.weight.copy_(q1.T)
            self.lora_A_v.weight.copy_(q2.T)

        nn.init.zeros_(self.lora_B_q.weight)
        nn.init.zeros_(self.lora_B_v.weight)

    def forward(
            self,
            x,
            adapt=None,
            prompt=None,
            rank_prompt=None,
            block_weight=None,
            attn_mask=None,
            register_hook=False):
        """Apply attention, optionally adding adapter outputs to q and v.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            adapt: optional indexable of callables; ``adapt[0](x)`` is added
                to the query channels and ``adapt[2](x)`` to the value
                channels of the fused projection.
            prompt, rank_prompt: accepted but not used here.
            block_weight: optional indexable of scales; entries 0 and 2
                weight the q and v adapter outputs. Defaults to ones.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            register_hook: keep attention weights and hook their gradient.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape

        # The bias is passed detached (.data) as before; it is None when the
        # module was built with qkv_bias=False, which used to raise here.
        bias = None if self.qkv.bias is None else self.qkv.bias.data
        qkv = F.linear(x, self.qkv.weight, bias)

        if adapt is not None:
            if block_weight is None:
                block_weight = torch.ones(3).to(x.device)

            # Fused projection: first `dim` channels are q, last `dim` are v.
            qkv[:, :, : self.dim] += block_weight[0] * adapt[0](x)
            qkv[:, :, -self.dim :] += block_weight[2] * adapt[2](x)

        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
| |
|
| | |
class MultiHeadAttention_MaskedLoRA(MultiHeadAttention_LoRA):
    """LoRA attention (k, v) with per-expert learnable subspace scaling.

    For each of 10 expert slots there are up to two (space, scale) pairs.
    When a pair is enabled, the k and v weights are modified as
    ``W + W @ S @ (M @ M.T - I) @ S.T`` where ``S`` is the stored space and
    ``M`` is the top-left ``S.shape[1]`` square crop of the scale parameter.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)

        # Identity of the qkv input width; kept as a plain (CPU) tensor.
        self.identity_matrix = torch.eye(self.qkv.weight.shape[1])

        # space[t][j] stays 0 until enable_scale() stores a tensor there.
        self.space = [[0, 0] for _ in range(10)]
        # Each parameter gets its own copy of the identity. nn.Parameter does
        # not copy its argument, so wrapping the same tensor 20 times would
        # make all scales (and identity_matrix) share one storage.
        self.scale_param = nn.ModuleList([
            nn.ParameterList([nn.Parameter(self.identity_matrix.clone()) for _ in range(2)])
            for _ in range(10)
        ])
        self.scaling_mask = [[False, False] for _ in range(10)]

    def enable_scale(self, task_id, space):
        """Store one or two spaces for ``task_id`` and enable their scaling.

        Args:
            task_id: expert slot index.
            space: sequence of length 1 or 2; other lengths store nothing.
        """
        if len(space) == 2:
            self.space[task_id][0] = space[0]
            self.space[task_id][1] = space[1]
            self.scaling_mask[task_id][0] = True
            self.scaling_mask[task_id][1] = True
        elif len(space) == 1:
            self.space[task_id][0] = space[0]
            self.scaling_mask[task_id][0] = True

        # Move the scale parameters next to the qkv weight. Assigning through
        # .data keeps the Parameter objects (the previous loop only rebound a
        # local variable and therefore moved nothing).
        device = self.qkv.weight.device
        for scale_param_list in self.scale_param:
            for scale_param in scale_param_list:
                scale_param.data = scale_param.data.to(device)

    def forward(self, x, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix = False):
        """Apply attention with LoRA and the scaling of one expert slot.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            expert_id: index of the expert slot whose scaling is applied.
            register_hook: keep attention weights and hook their gradient.
            prompt: accepted for signature compatibility; not used here.
            get_input_matrix: when True, update ``cur_matrix`` first.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        if get_input_matrix:
            self.cur_matrix = (self.cur_matrix*self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0]*x.shape[1])
            self.n_cur_matrix += x.shape[0]*x.shape[1]

        B, N, C = x.shape

        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)

        if self.apply_lora:
            k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
            v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight

        # NOTE(review): the scaling loop is placed at method level (applied
        # whether or not LoRA is active); the source indentation was lost,
        # confirm against the reference implementation.
        for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]):

            # Pairs are enabled in order, so the first disabled one ends the loop.
            if not mask:
                break

            scale_size = space.shape[1]
            cropped_scale = scale[:scale_size, :scale_size]
            cropped_scale = cropped_scale @ cropped_scale.T

            cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device)

            k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
            v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T

        # The bias is passed detached (.data) as before; it is None when the
        # module was built with qkv_bias=False, which used to raise here.
        bias = None if self.qkv.bias is None else self.qkv.bias.data

        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
| |
|
| | |
class MultiHeadAttention_MaskedLoRA1(MultiHeadAttention):
    """Attention with one LoRA pair (k, v) per task, all summed in forward.

    Every call to :meth:`init_param` appends a new LoRA pair for k and v and
    advances ``cur_task``. :meth:`forward` adds the updates of all pairs to
    the k and v slices of the fused qkv weight.

    NOTE(review): ``self.feature_mat`` is read when ``cur_task > 0`` but is
    never assigned in this class; it presumably has to be attached by the
    owner -- confirm against the caller.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)

        self.cur_task = -1
        self.lora_rank = lora_rank

        self.cur_matrix = torch.zeros(self.dim, self.dim)
        self.n_cur_matrix = 0
        # Inputs collected by forward(get_input_matrix=True). Initialised
        # here so collecting works even before reset_input_matrix() is called.
        self.cur_matrixs = []

        self.lora_bias = lora_bias

        self.lora_A_k_list = nn.ModuleList([])
        self.lora_B_k_list = nn.ModuleList([])
        self.lora_A_v_list = nn.ModuleList([])
        self.lora_B_v_list = nn.ModuleList([])

        self.space_k = [0 for _ in range(10)]
        self.space_v = [0 for _ in range(10)]
        self.identity_matrix = torch.eye(self.qkv.weight.shape[1])
        self.scale_param = nn.ParameterList([])

    def init_param(self):
        """Append a new LoRA pair for k and v and advance ``cur_task``.

        A is Kaiming-uniform, B is zero. A scale parameter initialised to the
        identity is appended as well.
        """
        self.lora_A_k_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias))
        self.lora_B_k_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias))
        self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias))
        self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias))
        # clone(): nn.Parameter does not copy its argument, so without it
        # every scale parameter would share storage with identity_matrix.
        self.scale_param.append(nn.Parameter(self.identity_matrix.clone()))

        nn.init.kaiming_uniform_(self.lora_A_k_list[-1].weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k_list[-1].weight)
        nn.init.zeros_(self.lora_B_v_list[-1].weight)

        self.cur_task += 1

    def reset_input_matrix(self):
        """Drop all inputs collected so far."""
        self.cur_matrixs = []

    def _rebase_current_lora(self, x):
        """Replace the newest A factor by an SVD basis and adapt B to match.

        The activation matrix ``sum_b x_b^T x_b / B`` is multiplied by
        ``A^T A`` of the newest key LoRA; for ``cur_task > 0`` its
        ``feature_mat`` component is removed. The first ``lora_rank`` left
        singular vectors (divided by sqrt(3)) become the new A for both k
        and v, and both B factors are right-multiplied by
        ``A_old @ pinv(A_new)``.
        """
        activation = torch.bmm(x.permute(0, 2, 1), x).sum(dim=0) / x.shape[0]

        A_old = self.lora_A_k_list[-1].weight.data
        activation = A_old.T @ A_old @ activation

        if self.cur_task > 0:
            activation = activation - self.feature_mat @ activation

        U, _, _ = torch.linalg.svd(activation, full_matrices = False)
        A_new = U[:,:self.lora_rank].T / math.sqrt(3)
        Bk_old = self.lora_B_k_list[-1].weight.data
        Bv_old = self.lora_B_v_list[-1].weight.data

        # NOTE(review): the value branch is also rebased with the *key* A
        # factor, as in the original.
        tmp = A_old @ torch.pinverse(A_new)
        Bk_new = Bk_old @ tmp
        Bv_new = Bv_old @ tmp

        self.lora_A_k_list[-1].weight.data.copy_(A_new)
        self.lora_A_v_list[-1].weight.data.copy_(A_new)
        self.lora_B_k_list[-1].weight.data.copy_(Bk_new)
        self.lora_B_v_list[-1].weight.data.copy_(Bv_new)

    def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
        """Apply attention with the summed LoRA updates of all tasks.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            x_proj: accepted for signature compatibility; not used here.
            probs: passed through unchanged.
            attn_mask: optional additive mask, unsqueezed at dim 0.
            expert_id: accepted for signature compatibility; not used here.
            register_hook: keep attention weights and hook their gradient.
            prompt: accepted for signature compatibility; not used here.
            get_input_matrix: when True, append the detached input to
                ``cur_matrixs`` (batch size must be below 512).

        Returns:
            ``(out, out, probs)`` where ``out`` has shape ``(B, N, C)``.
        """
        if get_input_matrix:
            assert x.shape[0] < 512
            self.cur_matrixs.append(x.detach())

            # NOTE(review): nested under get_input_matrix on the assumption
            # that rebasing belongs to the collection pass; the source
            # indentation was lost -- confirm.
            if x.shape[0] > 128:
                self._rebase_current_lora(x)

        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)

        # Earlier tasks first, then the newest pair.
        for ii in range(self.cur_task):
            k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight
            v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight

        k_weight = k_weight + self.lora_B_k_list[-1].weight @ self.lora_A_k_list[-1].weight
        v_weight = v_weight + self.lora_B_v_list[-1].weight @ self.lora_A_v_list[-1].weight

        # The bias is passed detached (.data) as before; it is None when the
        # module was built with qkv_bias=False, which used to raise here.
        bias = None if self.qkv.bias is None else self.qkv.bias.data

        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), bias).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, x, probs
| |
|
| | |
| | class MultiHeadAttention_MultiMaskedLoRA(MultiHeadAttention_MaskedLoRA): |
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False):
    """Build the parent module and add expert-selection bookkeeping."""
    super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)

    # Highest expert slot written by save_space(); slots 0..activated_expert
    # are read during evaluation.
    self.activated_expert = 0
    # Placeholder scalar tensors until save_space() stores real spaces.
    self.saved_space = [[torch.tensor(1), torch.tensor(1)] for _ in range(10)]

    self.hit = 0
    self.total = 0

    self.projected_cur_matrix = torch.zeros(self.dim, self.dim)
    self.n_projected_cur_matrix = 0
| |
|
def reset_input_matrix(self):
    """Clear the parent's input statistics and the projected ones."""
    super().reset_input_matrix()
    self.n_projected_cur_matrix = 0
    self.projected_cur_matrix.zero_()
| |
|
def enable_scale(self, task_id, space):
    """Store one or two spaces for ``task_id`` and enable their scaling.

    Args:
        task_id: expert slot index.
        space: sequence of length 1 or 2; other lengths store nothing.
    """
    if len(space) == 2:
        self.space[task_id][0] = space[0]
        self.space[task_id][1] = space[1]
        self.scaling_mask[task_id][0] = True
        self.scaling_mask[task_id][1] = True
    elif len(space) == 1:
        self.space[task_id][0] = space[0]
        self.scaling_mask[task_id][0] = True

    # Move the scale parameters next to the qkv weight. Assigning through
    # .data keeps the Parameter objects (the previous loop only rebound a
    # local variable and therefore moved nothing).
    device = self.qkv.weight.device
    for scale_param_list in self.scale_param:
        for scale_param in scale_param_list:
            scale_param.data = scale_param.data.to(device)
| |
|
def save_space(self, task_id, space):
    """Record ``space`` in slot ``task_id`` and mark it as the active expert."""
    self.activated_expert = task_id
    slot = self.saved_space[task_id]
    slot[0] = space
| |
|
def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
    """Run a plain LoRA pass on ``x`` and an expert-scaled pass on ``x_proj``.

    Args:
        x: tensor unpacked as ``(B, N, C)``; input of the first pass.
        x_proj: input of the second (scaled) pass; it is reshaped with the
            ``B, N, C`` taken from ``x``.
        probs: list; in eval mode the softmaxed projection norms of this
            call are appended to it.
        attn_mask: optional additive mask, unsqueezed at dim 0.
        expert_id: expert slot used for scaling; overridden in eval mode by
            the slot with the largest projection norm.
        register_hook: keep attention weights and hook their gradient (the
            second pass overwrites what the first one stored).
        prompt: accepted for signature compatibility; not used here.
        get_input_matrix: when True, update ``cur_matrix`` with this batch;
            ``expert_id`` must then be 0.

    Returns:
        ``(x_out, x_proj_out, probs)``.
    """
    B, N, C = x.shape

    if get_input_matrix:
        assert expert_id == 0
        self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N)
        self.n_cur_matrix += B * N

    if not self.training and not get_input_matrix:
        with torch.no_grad():
            # Average x^T x of this batch, projected onto every saved space;
            # the expert with the largest projection norm is selected.
            cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0) / (B * N)
            saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device)

            proj_mat = saved.transpose(1, 2)
            proj_mat = torch.einsum('ijk,kl->ijl', proj_mat, cur_cur_matrix)

            proj_norm = np.linalg.norm(proj_mat.cpu(), axis=(1, 2))
            proj_norm = softmax(proj_norm)
            probs.append(proj_norm)

            expert_id = np.argmax(proj_norm, axis = 0)

    q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)

    if self.apply_lora:
        k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight
        v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight

    # The bias is passed detached (.data) as before; it is None when the
    # module was built with qkv_bias=False, which used to raise here.
    bias = None if self.qkv.bias is None else self.qkv.bias.data

    def attend(tokens, k_w, v_w):
        # One full attention pass with the given key / value weights.
        qkv = F.linear(tokens, torch.cat([q_weight, k_w, v_w], dim=0), bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)

    # Pass 1: LoRA weights only.
    x = attend(x, k_weight, v_weight)

    # Pass 2: additionally scale k / v inside the selected expert's spaces:
    # W + W @ S @ (M @ M.T - I) @ S.T
    for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]):

        # Pairs are enabled in order, so the first disabled one ends the loop.
        if not mask:
            break

        scale_size = space.shape[1]
        cropped_scale = scale[:scale_size, :scale_size]
        cropped_scale = cropped_scale @ cropped_scale.T

        cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device)

        k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T
        v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T

    x_proj = attend(x_proj, k_weight, v_weight)

    return x, x_proj, probs
| |
|
| | def forward1(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): |
| | |
| | B, N, C = x.shape |
| |
|
| | if get_input_matrix: |
| | assert expert_id == 0 |
| | self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N) |
| | self.n_cur_matrix += B * N |
| | |
| | |
| | if not self.training and not get_input_matrix: |
| | with torch.no_grad(): |
| |
|
| | cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()) / N |
| | cur_cur_matrix = cur_cur_matrix.permute(1, 2, 0) |
| | saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) |
| | proj_mat = saved.transpose(1, 2) |
| |
|
| | proj_mat = torch.einsum('ijk,klm->ijlm', proj_mat, cur_cur_matrix) |
| |
|
| | proj_norm = np.linalg.norm(proj_mat, axis=(1, 2)) |
| | proj_norm = softmax(proj_norm, axis=0) |
| | |
| | probs.append(proj_norm) |
| |
|
| | selected_expert_id = np.argmax(proj_norm, axis = 0) |
| | selected_expert_id = torch.tensor(selected_expert_id).to(x.device) |
| |
|
| |
|
| | q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
| |
|
| | if self.apply_lora: |
| | k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| | v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| | |
| | qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| | |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | attn = (q @ k.transpose(-2, -1)) * self.scale |
| |
|
| | if attn_mask is not None: |
| | attn += attn_mask.unsqueeze(0) |
| |
|
| | attn = attn.softmax(dim=-1) |
| | attn = self.attn_drop(attn) |
| | |
| | if register_hook: |
| | self.save_attention_map(attn) |
| | attn.register_hook(self.save_attn_gradients) |
| |
|
| | x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| | x = self.proj(x) |
| | x = self.proj_drop(x) |
| |
|
| | |
| | if not self.training and not get_input_matrix: |
| | inputs = [x_proj.clone() for _ in range(self.activated_expert + 1)] |
| | k_weights = [k_weight.clone() for _ in range(self.activated_expert + 1)] |
| | v_weights = [v_weight.clone() for _ in range(self.activated_expert + 1)] |
| | qkv_outputs = [] |
| |
|
| | for ex in range(self.activated_expert + 1): |
| |
|
| | for mask, scale, space in zip(self.scaling_mask[ex], self.scale_param[ex], self.space[ex]): |
| |
|
| | if not mask: |
| | break |
| |
|
| | scale_size = space.shape[1] |
| | cropped_scale = scale[:scale_size, :scale_size] |
| |
|
| | cropped_scale = cropped_scale @ cropped_scale.T |
| |
|
| | cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(x.device) |
| |
|
| | k_weights[ex] = k_weights[ex] + k_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| | v_weights[ex] = v_weights[ex] + v_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| |
|
| | cur_selected = selected_expert_id.unsqueeze(-1).unsqueeze(-1) |
| |
|
| | mask = (cur_selected == ex) |
| | inputs[ex] *= mask |
| |
|
| | inputs[ex] = inputs[ex].to(x.device) |
| | q_weight = q_weight.to(x.device) |
| | k_weights[ex] = k_weights[ex].to(x.device) |
| | v_weights[ex] = v_weights[ex].to(x.device) |
| |
|
| | qkv = F.linear(inputs[ex], torch.cat([q_weight, k_weights[ex], v_weights[ex]], dim=0)) |
| | qkv_outputs.append(qkv) |
| |
|
| | stacked = torch.stack(qkv_outputs) |
| | qkv = torch.sum(stacked, dim=0) |
| | qkv = qkv + self.qkv.bias |
| | qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
|
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | attn = (q @ k.transpose(-2, -1)) * self.scale |
| |
|
| | if attn_mask is not None: |
| | attn += attn_mask.unsqueeze(0) |
| |
|
| | attn = attn.softmax(dim=-1) |
| | attn = self.attn_drop(attn) |
| | |
| | if register_hook: |
| | self.save_attention_map(attn) |
| | attn.register_hook(self.save_attn_gradients) |
| |
|
| | x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| | x_proj = self.proj(x_proj) |
| | x_proj = self.proj_drop(x_proj) |
| |
|
| | else: |
| |
|
| | for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): |
| |
|
| | if not mask: |
| | break |
| |
|
| | scale_size = space.shape[1] |
| | cropped_scale = scale[:scale_size, :scale_size] |
| |
|
| | cropped_scale = cropped_scale @ cropped_scale.T |
| |
|
| | cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) |
| |
|
| | k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| | v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| | |
| | qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| | |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | attn = (q @ k.transpose(-2, -1)) * self.scale |
| |
|
| | if attn_mask is not None: |
| | attn += attn_mask.unsqueeze(0) |
| |
|
| | attn = attn.softmax(dim=-1) |
| | attn = self.attn_drop(attn) |
| | |
| | if register_hook: |
| | self.save_attention_map(attn) |
| | attn.register_hook(self.save_attn_gradients) |
| |
|
| | x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| | x_proj = self.proj(x_proj) |
| | x_proj = self.proj_drop(x_proj) |
| |
|
| | return x, x_proj, probs |
| |
|
| | |
class MultiHeadAttention_MultiMaskedLoRA3(MultiHeadAttention_MaskedLoRA):
    """Masked-LoRA attention that keeps one LoRA pair and one scale matrix per task.

    For every task slot the module owns key/value LoRA factors
    (``lora_A_*_list`` / ``lora_B_*_list``), an optional subspace
    (``space_k`` / ``space_v``; the integer 0 means "not set") and a
    trainable scale matrix (``scale_param``). ``forward`` sums the LoRA
    deltas of all tasks up to ``cur_task`` into the key/value weights.

    Relies on attributes created by the parent class (``dim``,
    ``lora_rank``, ``identity_matrix``, ``qkv``, ``proj``, ``space``,
    ``scaling_mask``, ``saved_space``, ``cur_matrix``, ...).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False, num_tasks=10):
        """Build the per-task LoRA factors and scale matrices.

        Args:
            dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop,
            lora_rank, lora_bias: forwarded unchanged to the parent class.
            num_tasks: number of task slots to allocate (default 10, the
                value that used to be hard-coded).
        """
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)

        # Index of the task currently being trained; -1 until init_param() runs.
        self.cur_task = -1

        self.lora_A_k_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(num_tasks)])
        self.lora_B_k_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(num_tasks)])
        self.lora_A_v_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(num_tasks)])
        self.lora_B_v_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(num_tasks)])

        # 0 is the "no subspace stored" sentinel checked in forward().
        self.space_k = [0 for _ in range(num_tasks)]
        self.space_v = [0 for _ in range(num_tasks)]

        # Each slot needs its own storage: wrapping the same identity tensor in
        # several Parameters would make them (and identity_matrix itself) alias
        # one buffer, so an update to one task's scale would change all of them.
        self.scale_param = nn.ParameterList(
            [nn.Parameter(self.identity_matrix.detach().clone()) for _ in range(num_tasks)]
        )

    def init_param(self):
        """Advance to the next task and (re)initialize its LoRA factors.

        The A factors get Kaiming-uniform weights and the B factors zeros, so
        the new task's delta ``B @ A`` starts at zero.
        """
        self.cur_task += 1
        i = self.cur_task

        nn.init.kaiming_uniform_(self.lora_A_k_list[i].weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A_v_list[i].weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B_k_list[i].weight)
        nn.init.zeros_(self.lora_B_v_list[i].weight)

    def merge_weight(self):
        """Merging is disabled for this variant: log that and return 0.

        The merge logic that used to follow the early return was unreachable
        (and referenced attributes this class never defines), so it has been
        removed.
        """
        print('Not MERGED')
        return 0

    def save_dir(self):
        """Disabled for this variant; always returns 0.

        The SVD-based subspace extraction that used to follow the early
        return was unreachable and has been removed.
        """
        return 0

    def enable_scale(self, task_id, space):
        """Store one or two subspaces for ``task_id`` and enable their scaling masks.

        Args:
            task_id: index into ``self.space`` / ``self.scaling_mask``.
            space: sequence of length 1 or 2; any other length is ignored.
        """
        if len(space) == 2:
            self.space[task_id][0] = space[0]
            self.space[task_id][1] = space[1]
            self.scaling_mask[task_id][0] = True
            self.scaling_mask[task_id][1] = True
        elif len(space) == 1:
            self.space[task_id][0] = space[0]
            self.scaling_mask[task_id][0] = True
        # NOTE: a trailing loop that did `scale_param = scale_param.to(device)`
        # only rebound a local name and never moved anything; it was dropped.

    def save_space(self, task_id, space):
        """Mark ``task_id`` as the active expert and append ``space`` to its saved list."""
        self.activated_expert = task_id
        self.saved_space[task_id].append(space)

    def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False):
        """Single attention pass with the LoRA deltas of every task seen so far.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            x_proj, expert_id, prompt: accepted for interface compatibility;
                not used by this variant.
            probs: returned untouched.
            attn_mask: optional additive mask, broadcast via ``unsqueeze(0)``.
            register_hook: when true, saves the attention map and registers a
                gradient hook on it.
            get_input_matrix: when true, folds this batch into the running
                average ``self.cur_matrix``.

        Returns:
            ``(out, out, probs)`` - the attention output twice (same tensor
            object) plus the unchanged ``probs``.
        """
        B, N, C = x.shape

        if get_input_matrix:
            batch_gram = torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + batch_gram) / (self.n_cur_matrix + B * N)
            self.n_cur_matrix += B * N

        q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0)

        for ii in range(self.cur_task + 1):
            k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight
            v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight

            space_k = self.space_k[ii]
            if isinstance(space_k, int):
                # No subspace stored for this task: LoRA delta only.
                continue
            space_v = self.space_v[ii]

            # Only the diagonal of the cropped scale matrix is used. The same
            # matrix (cropped to space_k's row count) is applied to both the
            # key and the value subspace.
            rank = space_k.shape[0]
            scalee = torch.diag(torch.diag(self.scale_param[ii][:rank, :rank]))

            # Replace the component of W inside the subspace by its rescaled version.
            k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ space_k.T @ scalee @ space_k
            v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ space_v.T @ scalee @ space_v

        qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn += attn_mask.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, x, probs
| |
|
| |
|
| | |
class Mlp(nn.Module):
    """Two-layer feed-forward network (linear -> activation -> dropout -> linear -> dropout).

    Used as the MLP branch of Vision Transformer style blocks.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        """
        Args:
            in_features: input width.
            hidden_features: hidden width; falls back to ``in_features`` when falsy.
            out_features: output width; falls back to ``in_features`` when falsy.
            act_layer: zero-argument factory for the activation module.
            drop: dropout probability shared by both dropout applications.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| |
|
| | |
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Both residual branches pass through ``drop_path`` (stochastic depth when
    ``drop_path > 0``, identity otherwise).
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 norm_layer_eps = 1e-5,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,

                 lora_rank: int = 0,
                 lora_bias: bool = False
                 ):
        """
        Args:
            d_model: embedding width.
            n_head: number of attention heads.
            mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
            qkv_bias, qk_scale, attn_drop, proj_drop: forwarded to ``attn_layer``.
            drop_path: stochastic-depth rate; ``nn.Identity`` is used when not > 0.
            attn_layer: attention class; must be one of the variants handled below.
            act_layer: activation factory for the MLP.
            norm_layer, norm_layer_eps: normalization class and its ``eps``.
            attn_mask: optional mask handed to the attention layer on every call.
            text_or_image: stored as-is on the instance.
            lora_rank, lora_bias: forwarded to the LoRA attention variants only.
        """
        super().__init__()

        base_args = (d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        if attn_layer == MultiHeadAttention:
            # The plain attention layer takes no LoRA arguments.
            self.attn = attn_layer(*base_args)
        elif (attn_layer == MultiHeadAttention_LoRA
              or attn_layer == MultiHeadAttention_SDLoRA
              or attn_layer == MultiHeadAttention_LoRA_Sub
              or attn_layer == MultiHeadAttention_MaskedLoRA
              or attn_layer == MultiHeadAttention_MultiMaskedLoRA
              or attn_layer == MultiHeadAttention_CL_LoRA):
            self.attn = attn_layer(*base_args, lora_rank, lora_bias)
        else:
            assert 0, f'{attn_layer} not Implemented'

        self.ln_1 = norm_layer(d_model, eps=norm_layer_eps)
        self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer)
        self.ln_2 = norm_layer(d_model, eps=norm_layer_eps)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.attn_mask = attn_mask
        self.text_or_image = text_or_image

    def attention(self, x: torch.Tensor, **kwargs):
        """Run the attention layer on ``x`` with its first two axes swapped, then swap them back.

        The stored mask is first cast to ``x``'s dtype and device (and the
        cast result is kept on the instance).
        """
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(x)

        swapped = x.permute(1, 0, 2)
        attended = self.attn(swapped, attn_mask=self.attn_mask, **kwargs)
        return attended.permute(1, 0, 2)

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the attention and MLP residual branches; ``kwargs`` go to the attention layer."""
        attn_branch = self.attention(self.ln_1(x), **kwargs)
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.ln_2(x))
        x = x + self.drop_path(mlp_branch)

        return x
| |
|
class ResidualAttentionBlock_MLP(ResidualAttentionBlock):
    """Residual attention block with an extra adapter running in parallel to the MLP.

    The adapter output is added to the MLP output before the second residual
    connection. ``attention`` is inherited from ``ResidualAttentionBlock``.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,

                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 ):
        """Build the base block plus a bottleneck adapter of width 64.

        All arguments have the same meaning as in ``ResidualAttentionBlock``.
        """
        # The parent signature has `norm_layer_eps` between `norm_layer` and
        # `attn_mask`, so everything after `norm_layer` must go by keyword;
        # passed positionally, `attn_mask` would become the LayerNorm eps and
        # `text_or_image` the attention mask. The LoRA settings are forwarded
        # as well so LoRA attention layers receive the requested rank/bias.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)

        self.ffn_num = 64
        self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none')

        # Filled by forward() when compute_lora_feat is requested.
        self.lora_feature = None

    def forward(self, x: torch.Tensor, compute_lora_feat = False, **kwargs):
        """Apply attention, then MLP + adapter, each with a residual connection.

        Args:
            x: input tensor; its first two axes are swapped before the
                adapter call and swapped back afterwards.
            compute_lora_feat: when true, a detached CPU copy of the adapter
                output is stored in ``self.lora_feature``.
            **kwargs: forwarded to the attention layer.
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))

        # The adapter sees the post-attention activations (no layer norm).
        x_re = x.permute(1, 0, 2)
        adapt_x = self.adaptmlp(x_re, add_residual=False)
        adapt_x = adapt_x.permute(1, 0, 2)

        x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)

        if compute_lora_feat:
            self.lora_feature = adapt_x.detach().cpu()

        return x
| |
|
class ResidualAttentionBlock_MaskedMLP(ResidualAttentionBlock):
    """Residual attention block with a ``MaskedAdapter`` running in parallel to the MLP.

    ``attention`` is inherited from ``ResidualAttentionBlock``.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,

                 lora_rank: int = 0,
                 lora_bias: bool = False,
                 ):
        """Build the base block plus a masked bottleneck adapter of width 64.

        All arguments have the same meaning as in ``ResidualAttentionBlock``.
        """
        # The parent signature has `norm_layer_eps` between `norm_layer` and
        # `attn_mask`, so everything after `norm_layer` must go by keyword;
        # passed positionally, `attn_mask` would become the LayerNorm eps and
        # `text_or_image` the attention mask. The LoRA settings are forwarded
        # as well so LoRA attention layers receive the requested rank/bias.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)

        self.ffn_num = 64
        self.adaptmlp = MaskedAdapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                      init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none')

    def forward(self, x: torch.Tensor, compute_input_matrix = False, **kwargs):
        """Apply attention, then MLP + masked adapter, each with a residual connection.

        Args:
            x: input tensor; its first two axes are swapped before the
                adapter call and swapped back afterwards.
            compute_input_matrix: forwarded to the adapter.
            **kwargs: forwarded to the attention layer.
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))

        # The adapter sees the post-attention activations (no layer norm).
        x_re = x.permute(1, 0, 2)
        adapt_x = self.adaptmlp(x_re, add_residual=False, compute_input_matrix=compute_input_matrix)
        adapt_x = adapt_x.permute(1, 0, 2)

        x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)

        return x
| |
|
class ResidualAttentionBlock_MoE_MLP(ResidualAttentionBlock):
    """Residual attention block whose MLP branch is joined by a sparse mixture of adapters.

    A top-k router (see https://arxiv.org/abs/1701.06538) looks at the first
    token of every sample, picks ``top_k`` of the ``experts_num`` adapters and
    the gate-weighted adapter outputs are added to the MLP output.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,

                 lora_rank: int = 0,
                 lora_bias: bool = False,

                 step: int = 0,
                 experts_num: int = 0,
                 top_k: int = 0,
                 noisy_gating: bool = True
                 ):
        """
        Args:
            d_model ... lora_bias: same meaning as in ``ResidualAttentionBlock``.
            step: number of router / noise weight matrices to allocate;
                ``forward`` always uses index 0, so it needs ``step >= 1``.
            experts_num: number of adapter experts.
            top_k: experts selected per sample; must not exceed ``experts_num``.
            noisy_gating: add learned noise to the router logits while training.
        """
        # The parent signature has `norm_layer_eps` between `norm_layer` and
        # `attn_mask`, so everything after `norm_layer` must go by keyword;
        # passed positionally, `attn_mask` would become the LayerNorm eps and
        # `text_or_image` the attention mask. The LoRA settings are forwarded
        # as well so LoRA attention layers receive the requested rank/bias.
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            attn_mask=attn_mask,
            text_or_image=text_or_image,
            lora_rank=lora_rank,
            lora_bias=lora_bias)

        assert top_k <= experts_num

        # Parameters of the standard normal used by _prob_in_top_k.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        self.step = step
        self.top_k = top_k
        self.noisy_gating = noisy_gating

        self.ffn_num = 64
        self.experts_num = experts_num
        self.softmax = nn.Softmax(1)
        self.softplus = nn.Softplus()

        # Router and noise weights start at zero, i.e. all experts score equally.
        self.router_list = nn.ParameterList([
            nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step)
        ])
        self.w_noise_list = nn.ParameterList([
            nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step)
        ])

        self.adaptmlp_list = nn.ModuleList([
            Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                    init_option='lora',
                    adapter_scalar=0.1,
                    adapter_layernorm_option='none')
            for _ in range(self.experts_num)
        ])

        self.lora_feature = None

    def attention(self, x: torch.Tensor, **kwargs):
        """Run the attention layer on ``x`` with its first two axes swapped, then swap them back.

        The stored mask is first cast to ``x``'s dtype and device.
        """
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None

        x = x.permute(1, 0, 2)
        attn = self.attn(x, attn_mask=self.attn_mask, **kwargs)
        attn = attn.permute(1, 0, 2)

        return attn

    def cv_squared(self, x):
        """Return the squared coefficient of variation, ``var(x) / (mean(x)**2 + eps)``.

        Useful as a loss that pushes a positive distribution towards uniform.

        Args:
            x: a ``Tensor``.
        Returns:
            A scalar tensor, or a one-element zero tensor when ``x`` has
            exactly one entry along its first dimension.
        """
        eps = 1e-10

        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean()**2 + eps)

    def _gates_to_load(self, gates):
        """Count, per expert, the samples whose gate is strictly positive.

        Args:
            gates: a ``Tensor`` of shape [batch_size, n]
        Returns:
            a ``Tensor`` of shape [n] holding the counts
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Probability that each value lands in the top k under fresh gating noise.

        Helper of the noisy top-k gating scheme; it makes a load-balancing
        loss differentiable. It is not called by ``noisy_top_k_gating`` in
        this class, which returns ``None`` for the load.

        Args:
            clean_values: a ``Tensor`` of shape [batch, n].
            noisy_values: a ``Tensor`` of shape [batch, n]. Equal to clean values plus
                normally distributed noise with standard deviation noise_stddev.
            noise_stddev: a ``Tensor`` of shape [batch, n].
            noisy_top_values: a ``Tensor`` of shape [batch, m], the top-m noisy
                values per row with m >= k+1.
        Returns:
            a ``Tensor`` of shape [batch, n].
        """
        # `Normal` is not among the module-level imports visible for this
        # file, so it is imported here to keep the method self-contained.
        from torch.distributions.normal import Normal

        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat index of the (k+1)-th largest noisy value of every row: the
        # value an element currently inside the top k has to beat.
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        # Elements outside the top k have to beat the k-th largest value instead.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)

        normal = Normal(self.mean, self.std)

        prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
        """Noisy top-k gating.

        See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - noise is only added when this is true and
                ``self.noisy_gating`` is set.
            w_gate: router weights, multiplied as ``x @ w_gate``.
            w_noise: noise weights, multiplied as ``x @ w_noise``.
            noise_epsilon: a float added to the noise standard deviation.
        Returns:
            gates: a Tensor with shape [batch_size, num_experts]; zero outside
                the selected top-k entries, softmax weights inside.
            load: always ``None`` (the load computation is disabled).
        """
        clean_logits = x @ w_gate.to(x)

        if self.noisy_gating and train:
            raw_noise_stddev = x @ w_noise.to(x)
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
        else:
            logits = clean_logits

        # One extra candidate is fetched (when available), as needed by the
        # thresholds of _prob_in_top_k; only the first top_k are used here.
        top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1)
        top_k_logits = top_logits[:, :self.top_k]
        top_k_indices = top_indices[:, :self.top_k]
        top_k_gates = self.softmax(top_k_logits)
        zeros = torch.zeros_like(logits)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        return gates, None

    def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs):
        """Apply attention, then MLP + mixture-of-adapters, each with a residual connection.

        Args:
            x: 3-D input tensor; axes 0 and 1 are swapped before routing so
                that the (then) leading axis indexes the samples being routed.
            compute_lora_feat: accepted for interface compatibility; unused here.
            **kwargs: forwarded to the attention layer.
        """
        x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))

        dim0, dim1, width = x.shape
        x_swapped = x.permute(1, 0, 2)

        # The router only sees the first token of every sample.
        gates, _ = self.noisy_top_k_gating(x_swapped[:, 0, :], self.training, self.router_list[0],
                                           self.w_noise_list[0])

        dispatcher = SparseDispatcher(self.experts_num, gates)
        # reshape (not view): the permuted tensor is not guaranteed to be
        # laid out so that its last two axes can be merged without a copy.
        expert_inputs = dispatcher.dispatch(x_swapped.reshape(dim1, -1))

        expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0],
                                                                      dim0, width).to(x), add_residual=False)
                          for i in range(self.experts_num)]

        # Experts that received no sample are skipped before combining.
        expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0]

        y = dispatcher.combine(expert_outputs)
        y = y.view(dim1, dim0, width)
        x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2))

        return x
| |
|
| | class ResidualAttentionBlock_MoE_Proj(ResidualAttentionBlock): |
| | def __init__(self, |
| | d_model: int, |
| | n_head: int, |
| | mlp_ratio: float = 4., |
| | qkv_bias: bool = True, |
| | qk_scale: float = None, |
| | attn_drop: float = 0., |
| | proj_drop: float = 0., |
| | drop_path: float = 0., |
| | attn_layer = MultiHeadAttention, |
| | act_layer = nn.GELU, |
| | norm_layer = nn.LayerNorm, |
| | attn_mask: torch.Tensor = None, |
| | text_or_image=None, |
| | |
| | lora_rank: int = 0, |
| | lora_bias: bool = False, |
| | |
| | experts_num=0, |
| | ): |
| | super().__init__() |
| |
|
| | if isinstance(attn_layer, str): |
| | try: |
| | attn_layer = globals()[attn_layer] |
| | except KeyError: |
| | print(f'{attn_layer} not found, using default MultiHeadAttention') |
| | attn_layer = MultiHeadAttention |
| |
|
| | if isinstance(act_layer, str): |
| | try: |
| | act_layer = globals()[act_layer] |
| | except KeyError: |
| | print(f'{act_layer} not found, using default nn.GELU') |
| | act_layer = nn.GELU |
| | |
| | if isinstance(norm_layer, str): |
| | try: |
| | norm_layer = globals()[norm_layer] |
| | except KeyError: |
| | print(f'{norm_layer} not found, using default nn.LayerNorm') |
| | norm_layer = nn.LayerNorm |
| |
|
| | self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop) |
| | self.ln_1 = norm_layer(d_model) |
| | self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) |
| | self.ln_2 = norm_layer(d_model) |
| | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| | self.attn_mask = attn_mask |
| | self.is_train = True |
| | |
| | if experts_num > 1: |
| | self.register_buffer("mean", torch.tensor([0.0])) |
| | self.register_buffer("std", torch.tensor([1.0])) |
| | self.step = 1 |
| | else: |
| | self.step = 0 |
| | self.top_k = 2 |
| | self.ffn_num = 64 |
| | self.experts_num = experts_num |
| | self.softmax = nn.Softmax(1) |
| | self.softplus = nn.Softplus() |
| | self.noisy_gating = True |
| | self.text_or_image = text_or_image |
| | self.router_list = nn.ParameterList() |
| | self.w_noise_list = nn.ParameterList() |
| |
|
| | for i in range(self.step): |
| | self.router_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) |
| | self.w_noise_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) |
| | |
| | self.adaptmlp_list = nn.ModuleList() |
| | for i in range(self.experts_num): |
| | self.adaptmlp_list.append(Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, |
| | init_option='lora', |
| | adapter_scalar=0.1, |
| | adapter_layernorm_option='none', |
| | )) |
| |
|
| | self.lora_feature = None |
| | |
| | def attention(self, x: torch.Tensor, **kwargs): |
| | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
| | |
| | x = x.permute(1, 0, 2) |
| | attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| | attn = attn.permute(1, 0, 2) |
| |
|
| | return attn |
| |
|
def cv_squared(self, x):
    """The squared coefficient of variation of a sample.

    Useful as a loss to encourage a positive distribution to be more uniform.
    An epsilon is added to the denominator for numerical stability.

    Args:
        x: a `Tensor`.
    Returns:
        `var(x) / (mean(x)**2 + eps)` as a 0-dim float tensor. When `x` has
        fewer than two entries along its first axis (including the empty
        case) a zero tensor of shape `[1]` with the dtype/device of `x` is
        returned instead.
    """
    eps = 1e-10

    # The variance of zero or one samples is undefined (NaN). The documented
    # contract is to return zero, so cover the empty case as well as the
    # single-row case.
    if x.shape[0] <= 1:
        return torch.tensor([0], device=x.device, dtype=x.dtype)

    values = x.float()
    return values.var() / (values.mean() ** 2 + eps)
| |
|
def _gates_to_load(self, gates):
    """Count how many rows route to each expert.

    An expert (column) is counted for a row when its gate is strictly
    positive.

    Args:
        gates: a `Tensor` of shape [batch_size, n]
    Returns:
        an integer `Tensor` of shape [n] holding the per-column counts
    """
    is_active = torch.gt(gates, 0)
    return is_active.sum(dim=0)
| |
|
def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
    """Helper function to NoisyTopKGating.

    Computes the probability that a value is in the top k, given different
    random noise. This gives a way of backpropagating from a loss that
    balances the number of times each expert is in the top k experts per
    example.

    Args:
        clean_values: a `Tensor` of shape [batch, n].
        noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
            normally distributed noise with standard deviation noise_stddev.
        noise_stddev: a `Tensor` of shape [batch, n]; `clean_values` minus the
            threshold is divided by it.
        noisy_top_values: a `Tensor` of shape [batch, m] holding the top-m
            noisy values per row in descending order. Column `self.top_k` is
            read from every row, so m must be at least `self.top_k + 1`.
    Returns:
        a `Tensor` of shape [batch, n].
    """
    # Imported here so the helper does not depend on a module-level import of
    # `Normal`; none is visible at the top of this file.
    from torch.distributions.normal import Normal

    batch = clean_values.size(0)
    m = noisy_top_values.size(1)
    top_values_flat = noisy_top_values.flatten()

    # Flat index of the (k+1)-th largest noisy value in every row: the value an
    # entry currently inside the top k has to beat in order to stay there.
    threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
    threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
    is_in = torch.gt(noisy_values, threshold_if_in)

    # For an entry currently outside the top k the bar is the k-th largest value.
    threshold_positions_if_out = threshold_positions_if_in - 1
    threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)

    normal = Normal(self.mean, self.std)

    prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
    prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
    return torch.where(is_in, prob_if_in, prob_if_out)
| |
|
def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
    """Noisy top-k gating.

    See paper: https://arxiv.org/abs/1701.06538.

    Args:
        x: input Tensor with shape [batch_size, input_size]
        train: a boolean - noise is only added when this is true (and
            `self.noisy_gating` is set).
        w_gate: routing weights, multiplied with `x` to give one logit per expert.
        w_noise: weights producing the per-expert noise scale; only read when
            noise is added.
        noise_epsilon: a float added to the noise standard deviation.
    Returns:
        gates: a Tensor with shape [batch_size, num_experts]; the selected
            `self.top_k` experts per row carry softmax weights, all other
            entries are zero.
        load: always `None` in this implementation.
    """
    scores = torch.matmul(x, w_gate.to(x))
    if self.noisy_gating and train:
        spread = self.softplus(torch.matmul(x, w_noise.to(x))) + noise_epsilon
        scores = scores + torch.randn_like(scores) * spread

    # One candidate more than needed is requested, capped at the expert count.
    n_candidates = min(self.top_k + 1, self.experts_num)
    best_values, best_indices = torch.topk(scores, n_candidates, dim=1)

    k = self.top_k
    selected_weights = self.softmax(best_values[:, :k])
    gates = torch.zeros_like(scores).scatter(1, best_indices[:, :k], selected_weights)

    return gates, None
| |
|
def forward(self, x: torch.Tensor, **kwargs):
    """Apply attention, the MLP and the optional adapter experts to `x`.

    Three configurations are supported, selected by `self.experts_num`:

    * 0  - plain block: attention followed by the MLP.
    * 1  - the single adapter runs in parallel with the MLP and its output is
           added to the MLP output.
    * >1 - mixture of experts: a gate computed from token 0 of every sample
           picks adapters via `noisy_top_k_gating`; their combined output is
           added to the MLP output.

    Args:
        x: 3-D input tensor. The adapters are fed `x.permute(1, 0, 2)`.
        **kwargs: forwarded unchanged to `self.attention`. The optional key
            `compute_lora_feat` (default False) is additionally read here: in
            the single-adapter configuration a true value stores the detached
            adapter output on CPU in `self.lora_feature`.
    Returns:
        Tensor with the same shape as `x`.
    """
    # The flag was previously referenced without ever being defined in this
    # method; read it from kwargs without removing it, so the arguments that
    # reach `self.attention` are unchanged.
    compute_lora_feat = kwargs.get('compute_lora_feat', False)

    x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs))

    if self.experts_num == 0:
        x = x + self.drop_path(self.mlp(self.ln_2(x)))

    elif self.experts_num == 1:
        x_re = x.permute(1, 0, 2)
        adapt_x = self.adaptmlp_list[0](x_re, add_residual=False)
        adapt_x = adapt_x.permute(1, 0, 2)

        x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x)

        if compute_lora_feat:
            self.lora_feature = adapt_x.detach().cpu()

    else:
        swapped = x.permute(1, 0, 2)

        # Routing decision is taken from token 0 of every sample.
        gates, _ = self.noisy_top_k_gating(swapped[:, 0, :], self.is_train,
                                           self.router_list[0], self.w_noise_list[0])

        dispatcher = SparseDispatcher(self.experts_num, gates)
        # reshape (not view): the permuted tensor is not guaranteed to be
        # contiguous, and view raises on incompatible strides.
        expert_inputs = dispatcher.dispatch(swapped.reshape(x.shape[1], -1))

        expert_outputs = [
            self.adaptmlp_list[i](
                expert_inputs[i].view(expert_inputs[i].shape[0], x.shape[0], x.shape[2]).to(x),
                add_residual=False)
            for i in range(self.experts_num)
        ]

        # Experts that received no samples are dropped before combining.
        expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0]

        y = dispatcher.combine(expert_outputs)
        y = y.view(x.shape[1], x.shape[0], x.shape[2])
        x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2))

    return x
| |
|
class ResidualAttentionBiBlock(nn.Module):
    """Residual attention block that processes two streams side by side.

    The block receives a main stream `x` and a second stream `x_proj`. Both go
    through the same layer norms, the same attention module (called once with
    both streams) and the same MLP. A `probs` object is handed to the
    attention module and its updated value is returned.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 drop_path: float = 0.,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,

                 lora_rank: int = 0,
                 lora_bias: bool = False
        ):
        super().__init__()

        # Arguments shared by every supported attention implementation; the
        # LoRA variants additionally take the rank and the bias flag.
        shared_args = (d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        if attn_layer == MultiHeadAttention:
            self.attn = attn_layer(*shared_args)
        elif (attn_layer == MultiHeadAttention_LoRA
              or attn_layer == MultiHeadAttention_MaskedLoRA
              or attn_layer == MultiHeadAttention_MultiMaskedLoRA
              or attn_layer == MultiHeadAttention_MultiMaskedLoRA3
              or attn_layer == MultiHeadAttention_MaskedLoRA1):
            self.attn = attn_layer(*shared_args, lora_rank, lora_bias)
        else:
            assert 0, f'{attn_layer} not Implemented'

        self.ln_1 = norm_layer(d_model)
        self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer)
        self.ln_2 = norm_layer(d_model)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.attn_mask = attn_mask
        self.text_or_image = text_or_image

    def attention(self, x: torch.Tensor, x_proj, probs, **kwargs):
        """Call the attention module on both streams.

        The stored mask, when present, is converted to the dtype/device of `x`
        and written back. Both streams have their first two axes swapped
        before the call and swapped back afterwards.

        Returns:
            Tuple `(attn, attn_proj, probs)` as produced by `self.attn`, with
            the two tensors back in the caller's axis order.
        """
        mask = self.attn_mask
        self.attn_mask = mask.to(x) if mask is not None else None

        main_in = x.permute(1, 0, 2)
        proj_in = x_proj.permute(1, 0, 2)
        main_out, proj_out, probs = self.attn(main_in, proj_in, probs, attn_mask=self.attn_mask, **kwargs)

        return main_out.permute(1, 0, 2), proj_out.permute(1, 0, 2), probs

    def forward(self, x: torch.Tensor, x_proj, probs, **kwargs):
        """Run attention and MLP with residual connections on both streams.

        Returns:
            Tuple `(x, x_proj, probs)`.
        """
        normed_main = self.ln_1(x)
        normed_proj = self.ln_1(x_proj)
        attn_main, attn_proj, probs = self.attention(normed_main, normed_proj, probs, **kwargs)

        # Attention residuals (main stream first, then the second stream).
        x = x + self.drop_path(attn_main)
        x_proj = x_proj + self.drop_path(attn_proj)

        # MLP residuals, same order.
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        x_proj = x_proj + self.drop_path(self.mlp(self.ln_2(x_proj)))

        return x, x_proj, probs
| |
|
| | |
class Transformer(nn.Module):
    """A stack of `layers` identical blocks.

    `block_layer`, `attn_layer`, `act_layer` and `norm_layer` may each be given
    either as a class or as the name of an object defined at module level in
    this file; an unknown name falls back to the documented default after
    printing a notice.

    Args:
        width: model dimension passed to every block as `d_model`.
        layers: number of blocks.
        heads: number of attention heads passed to every block as `n_head`.
        block_layer: block class (or its name).
        attn_layer: attention class (or its name) passed to every block.
        act_layer: activation class (or its name) passed to every block.
        norm_layer: normalisation class (or its name) passed to every block.
        attn_mask: optional attention mask passed to every block.
        text_or_image: tag passed to every block.
        **kwargs: forwarded unchanged to every block constructor.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
        ):
        super().__init__()
        self.width = width
        self.layers = layers

        block_layer = self._resolve_layer(block_layer, ResidualAttentionBlock, 'ResidualAttentionBlock')
        attn_layer = self._resolve_layer(attn_layer, MultiHeadAttention, 'MultiHeadAttention')
        act_layer = self._resolve_layer(act_layer, nn.GELU, 'nn.GELU')
        norm_layer = self._resolve_layer(norm_layer, nn.LayerNorm, 'nn.LayerNorm')

        self.blocks = nn.ModuleList([
            block_layer(
                d_model=width,
                n_head=heads,
                attn_layer=attn_layer,
                act_layer=act_layer,
                norm_layer=norm_layer,
                attn_mask=attn_mask,
                text_or_image=text_or_image,
                **kwargs)
            for _ in range(layers)])

    @staticmethod
    def _resolve_layer(layer, default, default_name):
        """Turn a layer given by name into the module-level object of that name.

        Non-string values are returned untouched. A name that is not defined at
        module level yields `default` after a printed notice.
        """
        if not isinstance(layer, str):
            return layer
        try:
            return globals()[layer]
        except KeyError:
            print(f'{layer} not found, using default {default_name}')
            return default

    def forward(self, x: torch.Tensor, l2p_prompt=None, l2p_e_prompt_layer_idx=None, **kwargs):
        """Pass `x` through all blocks, optionally prepending L2P prompts.

        Args:
            x: input tensor; prompts are concatenated in front of it along
                dim 0.
            l2p_prompt: optional indexable of prompt tensors. The k-th entry is
                used at the k-th prompted layer, after its first two axes have
                been swapped.
            l2p_e_prompt_layer_idx: indices of the blocks in front of which a
                prompt is inserted. `None` (the default) means no block.
            **kwargs: forwarded unchanged to every block.
        Returns:
            Output of the last block.
        """
        # `None` replaces the former mutable `[]` default.
        prompted_layers = () if l2p_e_prompt_layer_idx is None else l2p_e_prompt_layer_idx

        prompt_counter = -1
        for i, block in enumerate(self.blocks):
            if l2p_prompt is not None and (i in prompted_layers):
                prompt_counter += 1
                batched_prompt = l2p_prompt[prompt_counter]
                batched_prompt = batched_prompt.permute(1, 0, 2)
                x = torch.cat([batched_prompt, x], dim=0)

            x = block(x, **kwargs)

        return x
| |
|
class Transformer_Proj(Transformer):
    """Transformer whose blocks carry a second stream and a `probs` object.

    Every block is called as `block(x, x_proj, probs, **kwargs)` and returns the
    updated triple. The output of the stack is the second stream.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
        ):
        super().__init__(
            width,
            layers,
            heads,
            block_layer,
            attn_layer,
            act_layer,
            norm_layer,
            attn_mask,
            text_or_image,
            **kwargs,
        )
        # Value returned by the last block during the most recent forward pass.
        self.probs = []

    def forward(self, x: torch.Tensor, **kwargs):
        """Run both streams through all blocks and return the second stream.

        The second stream starts as a copy of `x`; `self.probs` is reset to an
        empty list and then replaced by whatever each block returns.
        """
        second_stream = x.clone()
        self.probs = []
        for layer in self.blocks:
            x, second_stream, self.probs = layer(x, second_stream, self.probs, **kwargs)

        return second_stream
| |
|
class Transformer_CL_LoRA(Transformer):
    """Transformer that hands per-block adapters and block weights to its blocks.

    Blocks with index 6 and above receive one column of `block_weight`; the
    first six blocks receive `block_weight=None`.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 block_layer = ResidualAttentionBlock,
                 attn_layer = MultiHeadAttention,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm,
                 attn_mask: torch.Tensor = None,
                 text_or_image=None,
                 **kwargs
        ):
        super().__init__(
            width,
            layers,
            heads,
            block_layer,
            attn_layer,
            act_layer,
            norm_layer,
            attn_mask,
            text_or_image,
            **kwargs,
        )

    def forward(self, x, adapt, prompt, rank_prompt, block_weight, **kwargs):
        """Run `x` through all blocks.

        Args:
            x: input tensor handed to the first block.
            adapt: indexable holding one adapter entry per block.
            prompt: passed unchanged to every block.
            rank_prompt: passed unchanged to every block.
            block_weight: 2-D tensor; column `idx - 6` is given to block `idx`
                for `idx >= 6`.
            **kwargs: forwarded unchanged to every block.
        Returns:
            Output of the last block.
        """
        for idx, blk in enumerate(self.blocks):
            layer_adapter = adapt[idx]
            # Only the blocks from index 6 onwards get a weight column.
            layer_weight = block_weight[:, idx - 6] if idx >= 6 else None
            x = blk(
                x,
                adapt=layer_adapter,
                prompt=prompt,
                rank_prompt=rank_prompt,
                block_weight=layer_weight,
                **kwargs
            )

        return x
| |
|
| | |
class VisualTransformer(nn.Module):
    """Image encoder: strided-convolution patch embedding, a class token,
    learned positional embeddings, a `Transformer` stack and a final linear
    projection of the class token.

    Args:
        img_size: input image side length; together with `patch_size` it fixes
            the number of positional embeddings.
        patch_size: kernel size and stride of the patch convolution.
        in_chans: number of input channels.
        width: embedding dimension.
        depth: number of transformer blocks.
        heads: number of attention heads.
        output_dim: size of the projected output feature.
        text_or_image: tag forwarded to the `Transformer`.
        **kwargs: forwarded unchanged to the `Transformer` constructor.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int = 3,
                 width: int = 768,
                 depth: int = 12,
                 heads: int = 8,
                 output_dim: int = 512,
                 text_or_image: str = None,
                 **kwargs
        ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.width = width
        self.depth = depth
        self.heads = heads
        self.output_dim = output_dim

        self.conv1 = nn.Conv2d(
            in_channels=in_chans,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        init_scale = width ** -0.5
        n_tokens = (img_size // patch_size) ** 2 + 1  # patches plus the class token
        self.class_embedding = nn.Parameter(init_scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_scale * torch.randn(n_tokens, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, depth, heads, text_or_image=text_or_image, **kwargs)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, **kwargs):
        """Encode a batch of images.

        Args:
            x: image batch accepted by the patch convolution.
            **kwargs: forwarded unchanged to the transformer.
        Returns:
            The normalised class-token feature, multiplied by `self.proj`
            unless that attribute is `None`.
        """
        patches = self.conv1(x)
        # Flatten the spatial grid and move it in front of the channel axis.
        tokens = patches.reshape(patches.shape[0], patches.shape[1], -1).permute(0, 2, 1)

        # One class token per sample, built by broadcasting against zeros.
        cls_tokens = self.class_embedding.to(tokens.dtype) + torch.zeros(
            tokens.shape[0], 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        # The transformer is fed with the first two axes swapped.
        tokens = self.transformer(tokens.permute(1, 0, 2), **kwargs).permute(1, 0, 2)

        feature = self.ln_post(tokens[:, 0, :])

        if self.proj is not None:
            feature = feature @ self.proj

        return feature
| |
|
| | |
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
    https://arxiv.org/abs/2010.11929
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_layer=MultiHeadAttention,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 ckpt_layer=0,
                 transformer_layer=Transformer,
                 **kwargs):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head (not used here)
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            attn_layer: attention class handed to the transformer
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim (not used here)
            qkv_bias (bool): enable bias for qkv if True (not used here)
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set (not used here)
            representation_size (Optional[int]): not used here
            drop_rate (float): dropout rate applied after the positional embedding
            attn_drop_rate (float): attention dropout rate (not used here)
            drop_path_rate (float): stochastic depth rate (not used here)
            norm_layer: (nn.Module): normalization layer handed to the transformer
            ckpt_layer: not used here
            transformer_layer: the strings 'Transformer_Proj' and
                'Transformer_CL_LoRA' select those classes; any other value
                selects `Transformer`
            **kwargs: forwarded unchanged to the transformer constructor
        """
        super().__init__()

        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        if transformer_layer == 'Transformer_Proj':
            self.transformer = Transformer_Proj(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs)
        elif transformer_layer == 'Transformer_CL_LoRA':
            self.transformer = Transformer_CL_LoRA(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs)
        else:
            self.transformer = Transformer(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs)
        self.norm = partial(nn.LayerNorm, eps=1e-6)(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise `nn.Linear` (truncated normal weight, zero bias) and
        `nn.LayerNorm` (unit weight, zero bias) sub-modules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _add_cls_and_pos(self, x):
        """Prepend the class token, add positional embeddings, apply dropout."""
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        return self.pos_drop(x)

    def forward(self, x, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs):
        """Encode a batch of images.

        Args:
            x: image batch accepted by `self.patch_embed`.
            register_blk: index of the block that is called with
                `register_hook=True` (only on the per-block prompt path).
            prompt: optional prompt module. With `prompt_flag == 'l2p'` it is
                called as `prompt(x, cls_features=...)`; otherwise its
                `forward(q, i, x, train=..., task_id=...)` is called before
                every block.
            prompt_flag: `'l2p'` selects the L2P path.
            q: query handed to `prompt.forward` on the per-block prompt path.
            train: when true, the loss returned by `prompt.forward` is
                accumulated.
            task_id: handed to `prompt.forward`.
            cls_features: handed to the prompt module on the L2P path.
            **kwargs: forwarded unchanged to the transformer / blocks.
        Returns:
            * L2P path with a prompt: `(feature, reduce_sim)` where `feature`
              is the mean over the first `total_prompt_len` output tokens.
            * L2P path without a prompt: the first output token.
            * Otherwise: `(tokens, prompt_loss)` with all normalised output
              tokens and the accumulated prompt loss (zero if none).
        """
        x = self.patch_embed(x)

        if prompt_flag == 'l2p':
            batched_prompt = None
            e_prompt_layer_idx = []
            if prompt:
                # Prompts are inserted in front of block 0 only.
                e_prompt_layer_idx = [0]
                total_prompt_len = prompt.length * prompt.top_k * len(e_prompt_layer_idx)

                batched_prompt, reduce_sim = prompt(x, cls_features=cls_features)

            x = self._add_cls_and_pos(x)

            x = x.permute(1, 0, 2)
            x = self.transformer(
                x,
                l2p_prompt=batched_prompt,
                l2p_e_prompt_layer_idx=e_prompt_layer_idx,
                **kwargs
            )
            x = x.permute(1, 0, 2)

            x = self.norm(x)

            if prompt:
                x = x[:, :total_prompt_len]
                x = x.mean(dim=1)
                return x, reduce_sim
            return x[:, 0]

        x = self._add_cls_and_pos(x)

        prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device)
        if prompt is not None:
            for i, blk in enumerate(self.transformer.blocks):
                if train:
                    p_list, loss, x = prompt.forward(q, i, x, train=True, task_id=task_id)
                    # Out-of-place on purpose: on CPU `.to()` above returns the
                    # leaf tensor itself, and an in-place update of a leaf that
                    # requires grad raises a RuntimeError.
                    prompt_loss = prompt_loss + loss
                else:
                    p_list, _, x = prompt.forward(q, i, x, train=False, task_id=task_id)

                x = x.permute(1, 0, 2)
                x = blk(x, register_hook=register_blk == i, prompt=p_list, **kwargs)
                x = x.permute(1, 0, 2)
        else:
            x = x.permute(1, 0, 2)
            x = self.transformer(x, **kwargs)
            x = x.permute(1, 0, 2)

        x = self.norm(x)
        return x, prompt_loss

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)
| |
|
class VisionTransformer_CL_LoRA(VisionTransformer):
    """Vision Transformer with per-block LoRA adapters for continual learning.

    One adapter group (a `ModuleList` with one entry per slot of `msa`) is kept
    for every block. The first six blocks ("general") keep their adapters
    across tasks; the last six ("specific") get fresh adapters and a fresh
    `block_weight` whenever `add_adapter_to_list` is called, while the old ones
    are frozen and stored for inference.
    """

    # Index of the first task-specific block. Must agree with the threshold
    # used in `Transformer_CL_LoRA.forward`.
    _SPECIFIC_START = 6

    class Adapter_lora(nn.Module):
        """Bias-free low-rank pair: `lora_B` projects down, `lora_A` projects up.

        `lora_B` is initialised from the Q factor of a QR decomposition of a
        random matrix; `lora_A` starts at zero, so a new adapter outputs zero.
        Only `init_option="lora"` is implemented.
        """

        def __init__(self,
                     config=None,
                     d_model=None,
                     bottleneck=None,
                     dropout=0.0,
                     init_option="bert",
                     adapter_scalar="1.0",
                     adapter_layernorm_option="in"):
            super().__init__()

            self.n_embd = config.d_model if d_model is None else d_model
            self.down_size = config.attn_bn if bottleneck is None else bottleneck

            self.lora_A = nn.Linear(self.down_size, self.n_embd, bias=False)
            self.lora_B = nn.Linear(self.n_embd, self.down_size, bias=False)

            random_matrix = torch.rand(self.n_embd, self.down_size)
            q, r = torch.linalg.qr(random_matrix)
            with torch.no_grad():
                self.lora_B.weight.copy_(q.T)
            scaling_factor = 1.
            self.lora_B.weight.data *= scaling_factor

            if init_option == "bert":
                raise NotImplementedError
            elif init_option == "lora":
                with torch.no_grad():
                    nn.init.zeros_(self.lora_A.weight)
            else:
                raise NotImplementedError

        def forward(self, x):
            inter_x = self.lora_B(x)
            out = self.lora_A(inter_x)
            return out

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_layer=MultiHeadAttention,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 ckpt_layer=0,
                 transformer_layer=Transformer,
                 **kwargs):
        """
        Args:
            img_size .. transformer_layer: see `VisionTransformer.__init__`;
                all of them are passed through unchanged.
            **kwargs: forwarded to `VisionTransformer.__init__`. The key
                `lora_rank` is required; it sets the adapter bottleneck size.
        """
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            attn_layer=attn_layer,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            representation_size=representation_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            ckpt_layer=ckpt_layer,
            transformer_layer=transformer_layer,
            **kwargs
        )

        cfg_dict = {
            'use_distillation': True,
            'use_block_weight': True,
            'msa_adapt': True,
            'msa': [1, 0, 1],
            'specfic_pos': [6, 7, 8, 9, 10, 11],
            'general_pos': [0, 1, 2, 3, 4, 5],
            'ffn_adapt': True,
            'ffn_option': 'parallel',
            'ffn_adapter_layernorm_option': 'none',
            'ffn_adapter_init_option': 'lora',
            'ffn_adapter_scalar': '0.1',
            'ffn_num': kwargs['lora_rank'],
            # Follow the model width instead of a fixed 768 so the adapters
            # match the blocks for any `embed_dim`.
            'd_model': embed_dim,
            'vpt_on': False,
            'vpt_num': 0,
            # Keep the historical default device, but stay usable on machines
            # without CUDA.
            '_device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
        }

        from types import SimpleNamespace

        self.tuning_config = SimpleNamespace(**cfg_dict)
        self.config = self.tuning_config

        self._device = self.tuning_config._device
        self.msa_adapt = self.tuning_config.msa_adapt
        self.use_distillation = self.tuning_config.use_distillation
        self.use_block_weight = self.tuning_config.use_block_weight

        self.general_pos = self.tuning_config.general_pos
        self.specfic_pos = self.tuning_config.specfic_pos
        self.adapt_pos = self.general_pos + self.specfic_pos
        self.adapt_pos = sorted(self.adapt_pos)

        if self.msa_adapt:
            self.msa = self.tuning_config.msa

        if self.use_distillation:
            self.old_adapter_list = nn.ModuleList()

        if self.use_block_weight:
            self.block_weight_list = []
            self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos)))
            nn.init.uniform_(self.block_weight, .5, 1.5)

        self.adapter_list = []
        self.adapter_pos_list = []
        self.cur_adapter = nn.ModuleList()
        self.get_new_adapter_initial_msa()

    # ------------------------------------------------------------------ helpers

    def _embed_tokens(self, x):
        """Patch-embed `x`, prepend the class token, add positional embeddings
        and apply dropout."""
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        return self.pos_drop(x)

    def _run_blocks(self, x, specific_adapters, adapter_offset, block_weight):
        """Run embedded tokens through all blocks and the final norm.

        Blocks below `_SPECIFIC_START` always use `self.cur_adapter`. Block
        `idx` at or above it uses `specific_adapters[idx - adapter_offset]` and
        column `idx - _SPECIFIC_START` of `block_weight`.
        """
        start = self._SPECIFIC_START
        x = x.permute(1, 0, 2)
        for idx, blk in enumerate(self.transformer.blocks):
            if idx >= start:
                x = blk(x, adapt=specific_adapters[idx - adapter_offset], prompt=None, rank_prompt=None,
                        block_weight=block_weight[:, idx - start])
            else:
                x = blk(x, adapt=self.cur_adapter[idx], prompt=None, rank_prompt=None, block_weight=None)
        x = x.permute(1, 0, 2)
        return self.norm(x)

    def _make_adapter_group(self):
        """Build one adapter group: an `Adapter_lora` for every slot of
        `self.msa` equal to 1 and `nn.Identity` for every other slot."""
        config = self.config
        group = nn.ModuleList()
        for flag in self.msa:
            if flag == 1:
                adapter = VisionTransformer_CL_LoRA.Adapter_lora(
                    self.config, dropout=0.0, bottleneck=config.ffn_num,
                    init_option=config.ffn_adapter_init_option,
                    adapter_scalar=config.ffn_adapter_scalar,
                    adapter_layernorm_option=config.ffn_adapter_layernorm_option,
                ).to(self._device)
            else:
                adapter = nn.Identity()
            group.append(adapter)
        return group

    # ------------------------------------------------------------------ forward

    def forward(self, x, test = False, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs):
        """Return `(features, None)` for a batch of images.

        With `test=False` the feature is the class token produced with the
        current adapters. With `test=True` the class tokens obtained with every
        stored adapter set and with the current one are concatenated along
        dim 1. All other arguments are accepted for interface compatibility
        and are not used.
        """
        if not test:
            output = self.forward_train(x)
            return output[:, 0], None

        features = self.forward_test(x)
        output = torch.cat([feature[:, 0, :] for feature in features], dim=1)
        return output, None

    def forward_train(self, x):
        """Return all normalised output tokens using the current adapters."""
        x = self._embed_tokens(x)

        x = x.permute(1, 0, 2)
        x = self.transformer(
            x,
            adapt=self.cur_adapter,
            prompt=None,
            rank_prompt=None,
            block_weight=self.block_weight)
        x = x.permute(1, 0, 2)
        return self.norm(x)

    def forward_test(self, x, use_init_ptm=False):
        """Return a list with one token tensor per stored adapter set, followed
        by the one for the current adapters. `use_init_ptm` is not used."""
        x_init = self._embed_tokens(x)

        features = []
        assert self.config.ffn_adapt
        assert self.adapt_pos == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        assert self.general_pos == [0, 1, 2, 3, 4, 5]
        assert self.use_block_weight

        # Every pass starts from its own copy of the embedded tokens. `clone`
        # is used because `copy.deepcopy` refuses non-leaf tensors that
        # require grad.
        for i in range(len(self.adapter_list)):
            features.append(self._run_blocks(
                x_init.clone(), self.adapter_list[i], self._SPECIFIC_START, self.block_weight_list[i]))

        features.append(self._run_blocks(x_init.clone(), self.cur_adapter, 0, self.block_weight))

        return features

    def forward_proto(self, x, adapt_index):
        """Return the class-token feature for one adapter set.

        `adapt_index` below `len(self.adapter_list)` selects that stored set;
        any larger value selects the current adapters.
        """
        assert adapt_index > -1
        assert self.config.ffn_adapt
        assert self.use_block_weight

        x = self._embed_tokens(x)

        if adapt_index < len(self.adapter_list):
            x = self._run_blocks(x, self.adapter_list[adapt_index], self._SPECIFIC_START,
                                 self.block_weight_list[adapt_index])
        else:
            x = self._run_blocks(x, self.cur_adapter, 0, self.block_weight)

        return x[:, 0, :]

    def forward_general_cls(self, x, t_idx):
        """Return class-token features after the general blocks only.

        The same input is run through the first six blocks twice: once with
        the current adapters and once with the adapter snapshot
        `self.old_adapter_list[t_idx - 1]`.

        Returns:
            Tuple `(output_new, output_teacher)`.
        """
        x = self._embed_tokens(x)

        # NOTE(review): every other forward path swaps the first two axes
        # before calling the blocks; this path previously did not, so it fed
        # the blocks a different axis order. It now follows the same
        # convention.
        x = x.permute(1, 0, 2)
        x_teacher = x.clone()
        teacher_adapters = self.old_adapter_list[t_idx - 1]

        for j in range(self._SPECIFIC_START):
            blk = self.transformer.blocks[j]
            x = blk(x, adapt=self.cur_adapter[j])
            x_teacher = blk(x_teacher, adapt=teacher_adapters[j])

        x = self.norm(x.permute(1, 0, 2))
        output_new = x[:, 0, :]

        x_teacher = self.norm(x_teacher.permute(1, 0, 2))
        output_teacher = x_teacher[:, 0, :]

        return output_new, output_teacher

    # ------------------------------------------------------- adapter management

    def get_new_adapter_initial_msa(self):
        """Create a trainable adapter group for every adapted block."""
        config = self.config
        if config.ffn_adapt:
            for _ in range(len(self.adapt_pos)):
                self.cur_adapter.append(self._make_adapter_group())
            self.cur_adapter.requires_grad_(True)
        else:
            print("====Not use adapter===")

    def add_adapter_to_list(self):
        """Freeze and archive the current task-specific state, then start a new one.

        * Frozen copies of the adapter groups at the specific positions are
          appended to `self.adapter_list` (the current groups at those
          positions are frozen as a side effect).
        * With block weights enabled, a frozen copy of `self.block_weight` is
          stored and a fresh one is drawn uniformly from [0.5, 1.5].
        * With distillation enabled, a frozen copy of all current adapters is
          appended to `self.old_adapter_list`.
        * With `msa_adapt` enabled, new adapters are created via
          `get_new_adapter_msa`.
        """
        import copy

        temp_adapter = []
        for specific_pos in self.specfic_pos:
            temp_pos = self.adapt_pos.index(specific_pos)
            temp_adapter.append(copy.deepcopy(self.cur_adapter[temp_pos].requires_grad_(False)))
        self.adapter_list.append(temp_adapter)

        if self.use_block_weight:
            previous = self.block_weight
            self.block_weight_old = copy.deepcopy(previous)
            self.block_weight_list.append(self.block_weight_old.requires_grad_(False))

            # Draw on CPU exactly as before, then place the new parameter
            # where the previous one lived so it matches the rest of the model.
            fresh = torch.randn(3, len(self.specfic_pos))
            nn.init.uniform_(fresh, .5, 1.5)
            self.block_weight = nn.Parameter(fresh.to(device=previous.device, dtype=previous.dtype))

        self.adapter_pos_list.append(self.adapt_pos)

        if self.use_distillation:
            self.old_adapter_list.append(copy.deepcopy(self.cur_adapter).requires_grad_(False))
        if self.msa_adapt:
            self.get_new_adapter_msa()

    def get_new_adapter_msa(self):
        """Replace the adapters at the specific positions with fresh ones.

        Afterwards all current adapters are made trainable again, except that
        `lora_B` of every adapter at a general position is frozen.
        """
        config = self.config

        if config.ffn_adapt:
            for specific_pos in self.specfic_pos:
                pos = self.adapt_pos.index(specific_pos)
                self.cur_adapter[pos] = self._make_adapter_group()

            if len(self.specfic_pos) < 12:
                self.cur_adapter.requires_grad_(True)

                for i in self.adapt_pos:
                    if i in self.general_pos:
                        pos = self.adapt_pos.index(i)
                        for j in range(len(self.msa)):
                            if self.msa[j] == 1:
                                self.cur_adapter[pos][j].lora_B.requires_grad_(False)
        else:
            print("====Not use adapter===")

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)
| |
|
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """Load weights from .npz checkpoints for official Google Brain Flax implementation.

    Each array read from the archive is converted to a tensor and copied in
    place (``Tensor.copy_``) into the corresponding parameter of ``model``.
    The function runs under ``torch.no_grad()``.

    Args:
        model: Target model. The code reads ``patch_embed``, ``cls_token``,
            ``pos_embed``, ``norm`` and ``blocks`` from it; every block must
            expose ``norm1``, ``attn.qkv``, ``attn.proj``, ``mlp.fc1``,
            ``mlp.fc2`` and ``norm2``.
        checkpoint_path: Path passed to ``np.load``.
        prefix: Prefix prepended to every archive key. When it is empty and
            the archive contains ``'opt/target/embedding/kernel'``, the
            prefix ``'opt/target/'`` is used instead.

    Raises:
        KeyError: If a key expected by the loader is absent from the archive.
    """
    import numpy as np

    def _n2p(w, t=True):
        """Turn a numpy array into a torch tensor, optionally permuting axes."""
        # A 4-D array whose first three dimensions are all 1 is collapsed to 1-D.
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            # Move the trailing axis to the front (for 4-D the second-to-last
            # axis becomes second). Presumably this maps Flax kernel layout to
            # the PyTorch weight layout -- TODO confirm.
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    # np.load returns an NpzFile that keeps the archive open until it is
    # closed. Every array is copied into the model inside this block, so the
    # handle can be released as soon as loading finishes (or fails).
    with np.load(checkpoint_path) as w:
        if not prefix and 'opt/target/embedding/kernel' in w:
            prefix = 'opt/target/'

        if hasattr(model.patch_embed, 'backbone'):
            # The patch embedding wraps a backbone: load its stem first, then
            # its stages when the backbone has a separate `stem` attribute.
            backbone = model.patch_embed.backbone
            # Without a `stem` attribute the backbone itself is used as the stem.
            stem_only = not hasattr(backbone, 'stem')
            stem = backbone if stem_only else backbone.stem
            stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
            stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
            stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
            if not stem_only:
                for i, stage in enumerate(backbone.stages):
                    for j, block in enumerate(stage.blocks):
                        # Archive keys are 1-based: block{i+1}/unit{j+1}/.
                        bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                        for r in range(3):
                            getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                            getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                            getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                        if block.downsample is not None:
                            block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                            block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                            block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
            embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        else:
            # Plain patch embedding: the kernel goes through adapt_input_conv
            # together with the projection's input-channel count.
            embed_conv_w = adapt_input_conv(
                model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
        model.patch_embed.proj.weight.copy_(embed_conv_w)
        model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
        if pos_embed_w.shape != model.pos_embed.shape:
            # Shape mismatch: hand the checkpoint embedding to resize_pos_embed
            # along with the model's token count (default 1) and patch grid size.
            pos_embed_w = resize_pos_embed(
                pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        model.pos_embed.copy_(pos_embed_w)
        model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        for i, block in enumerate(model.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            # The checkpoint stores query/key/value separately; each kernel is
            # flattened from dim 1, transposed, and the three are concatenated
            # in that order to fill the fused qkv weight.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                # Dense_0 / Dense_1 in the archive map to mlp.fc1 / mlp.fc2.
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
| |
|