| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| import math |
| import torch.nn.functional as F |
|
|
| from transformers import AutoConfig, PretrainedConfig |
| from jaxtyping import Float |
| from dataclasses import asdict, dataclass |
| from typing import List, Optional, Tuple, Dict |
| import einops |
|
|
|
|
| from .configIBA import MainConfig, HyperXSConfig, TrainingConfig |
|
|
|
|
|
|
def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is set, otherwise ``weight`` unchanged."""
    if fan_in_fan_out:
        return weight.T
    return weight
|
|
| class LoraLayer: |
| def __init__( |
| self, |
| |
| rank: int, |
| train_cfg: TrainingConfig, |
| |
| lora_alpha: int, |
| lora_dropout: float, |
| ): |
| self.rank = rank |
| self.batch_train = train_cfg.per_device_train_batch_size |
| self.batch_valid = train_cfg.per_device_eval_batch_size |
| |
| self.lora_alpha = lora_alpha |
| |
| if lora_dropout > 0.0: |
| self.lora_dropout = nn.Dropout(p=lora_dropout) |
| else: |
| self.lora_dropout = lambda x: x |
| |
| self.disable_adapters = False |
|
|
class LoraXSLinear(nn.Linear, LoraLayer):
    """``nn.Linear`` augmented with a LoRA-XS adapter.

    The frozen low-rank factors ``lora_A`` (in_features, rank) and ``lora_B``
    (rank, out_features) are registered as buffers; the only trainable adapter
    piece is the small ``rank x rank`` core ``lora_R`` installed via
    :meth:`set_R`. The adapter contribution is
    ``dropout(x) @ lora_A @ (lora_R @ lora_B) * scaling`` and is active only
    when ``rank > 0``, ``lora_R`` is set, and adapters are not disabled.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        train_cfg: TrainingConfig,
        rank: int = 64,

        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, rank=rank, train_cfg=train_cfg, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout)

        # True when the pretrained weight is stored transposed
        # (e.g. GPT-2 style Conv1D checkpoints).
        self.fan_in_fan_out = fan_in_fan_out

        # Always present so forward()/set_R() never hit AttributeError,
        # even when rank == 0.
        self.lora_R = None

        if rank > 0:
            # Frozen projection factors: buffers so they are saved/moved with
            # the module but never trained.
            self.register_buffer("lora_A", torch.zeros([in_features, rank]), persistent=True)
            self.register_buffer("lora_B", torch.zeros([rank, out_features]), persistent=True)

            self.scaling = self.lora_alpha / self.rank

            # The base weight stays frozen; only lora_R (set later) trains.
            self.weight.requires_grad = False

        if fan_in_fan_out:
            self.weight.data = self.weight.data.T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base linear layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A, mode='fan_out', a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_B, mode='fan_in', a=math.sqrt(5))

    def set_R(self, R: torch.Tensor):
        """Install the trainable ``rank x rank`` core matrix of the adapter."""
        self.lora_R = R

    def decompose_weight_svd(self, rank):
        """Overwrite ``lora_A``/``lora_B`` from a truncated SVD of the frozen weight.

        Factors ``W ≈ (U_r sqrt(S_r)) @ (sqrt(S_r) V_r^T) = B @ A`` and stores the
        transposed pieces so they compose as ``x @ lora_A @ lora_B``.

        Args:
            rank: number of singular values to keep (clamped to ``min(W.shape)``).

        Returns:
            ``(lora_A, lora_B)`` on success, ``(None, None)`` if the SVD fails.
            (Previously the success path implicitly returned ``None``, which made
            ``a, b = layer.decompose_weight_svd(r)`` crash only when it worked.)
        """
        W = self.weight.data
        device, dtype = W.device, W.dtype

        try:
            U, S, Vt = torch.linalg.svd(W, full_matrices=False)
        except torch.linalg.LinAlgError as e:
            print(f"SVD computation failed: {e}")
            return None, None

        # Explicitly clamp instead of relying on silent slice truncation when
        # rank exceeds the number of available singular values.
        rank = min(rank, S.shape[0])

        U_r = U[:, :rank]
        sqrt_S_r_diag = torch.diag(torch.sqrt(S[:rank]))
        Vt_r = Vt[:rank, :]

        B = U_r @ sqrt_S_r_diag      # (out_features, rank)
        A = sqrt_S_r_diag @ Vt_r     # (rank, in_features)

        # Store transposed: lora_A (in_features, rank), lora_B (rank, out_features),
        # matching the buffer shapes registered in __init__.
        self.lora_A = A.T.to(device, dtype)
        self.lora_B = B.T.to(device, dtype)
        return self.lora_A, self.lora_B

    def forward(self, x: torch.Tensor):
        previous_dtype = self.weight.dtype

        # Frozen base projection — identical for every path, so compute it once
        # (the original duplicated this call across three branches).
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        # Low-rank update only when the adapter is enabled and fully configured.
        if not self.disable_adapters and self.rank > 0 and self.lora_R is not None:
            result = result + (self.lora_dropout(x) @ self.lora_A) @ (self.lora_R @ self.lora_B) * self.scaling

        if result.dtype != previous_dtype:
            result = result.to(previous_dtype)

        return result
| |
|
|
class HyperNetXSexp(nn.Module):
    """Hypernetwork emitting one ``rank x rank`` LoRA-XS core per target module.

    Given latent features of a transformer layer, it concatenates a per-module
    embedding (and optionally a per-layer embedding), pushes the result through
    a pre-LayerNorm MLP, and reshapes the zero-initialised output head into
    ``(batch, n_modules, rank, rank)`` core matrices.
    """

    def __init__(
        self,
        hyperxs_cfg: HyperXSConfig,
        hf_model_cfg: PretrainedConfig,

    ):
        super().__init__()
        self.n_modules = hyperxs_cfg.modules_per_layer
        self.rank = hyperxs_cfg.lora_attn_dim
        self.latent_feature_dim = hyperxs_cfg.latent_feature_dim

        self.module_embed_dim = hyperxs_cfg.module_embed_dim
        self.layer_embed_dim = hyperxs_cfg.layer_embed_dim
        # Flattened size of one rank x rank core matrix.
        self.hyper_out = hyperxs_cfg.lora_attn_dim ** 2

        # MLP input: flattened cross-attention tokens + module + layer embeddings.
        n_flat_indim = self.latent_feature_dim * hyperxs_cfg.n_cross_attn_tokens + self.module_embed_dim + self.layer_embed_dim

        n_flat_outdim = hyperxs_cfg.out_proj_dim * hyperxs_cfg.n_cross_attn_tokens
        n_proj = 4 * n_flat_outdim  # conventional 4x MLP expansion

        self.latent_proj = nn.Linear(hf_model_cfg.hidden_size, self.latent_feature_dim)
        self.mixture = nn.Linear(n_flat_indim, n_flat_outdim)
        self.c_fc = nn.Linear(n_flat_outdim, n_proj)
        self.c_proj = nn.Linear(n_proj, self.hyper_out)
        self.act = nn.GELU()

        # Pre-LayerNorms, one per projection stage.
        self.ln_latent = nn.LayerNorm(hf_model_cfg.hidden_size, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_1 = nn.LayerNorm(n_flat_indim, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(n_flat_outdim, eps=hyperxs_cfg.layer_norm_epsilon)

        self.layer_embedding = nn.Embedding(hf_model_cfg.num_hidden_layers, self.layer_embed_dim)
        self.module_embedding = nn.Embedding(self.n_modules, self.module_embed_dim)
        self.hyperxs_cfg = hyperxs_cfg
        self.hf_model_cfg = hf_model_cfg

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the hidden projections; zero-init the output head so the
        hypernetwork initially predicts all-zero cores."""
        INIT_STD = 1e-3
        nn.init.kaiming_normal_(self.latent_proj.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.latent_proj.bias, 0)

        nn.init.kaiming_normal_(self.mixture.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.mixture.bias, 0)

        nn.init.kaiming_normal_(self.c_fc.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.c_fc.bias, 0)

        nn.init.normal_(self.layer_embedding.weight, mean=0.0, std=INIT_STD)
        # NOTE(review): module_embedding keeps nn.Embedding's default init while
        # layer_embedding gets std=1e-3 — confirm the asymmetry is intentional.

        # Zero output head => forward() returns exactly zero at initialisation.
        nn.init.constant_(self.c_proj.weight, 0)
        nn.init.constant_(self.c_proj.bias, 0)

    def forward(self, x: Float[Tensor, 'b s f'], layer_idx) -> Float[Tensor, 'b r in out']:
        batch_size = x.shape[0]
        dtype_in = x.dtype
        # Compute in the hypernetwork's parameter dtype; cast back at the end.
        x = x.to(self.latent_proj.weight.dtype)

        x = self.ln_latent(x)
        x = self.latent_proj(x)

        # (b, s, f) -> (b, s*f): flatten the cross-attention token dimension.
        x = einops.rearrange(x, 'batch seq fea -> batch (seq fea)')

        # Pair every sample with every module embedding.
        module_embedding = self.module_embedding.weight
        module_embedding = module_embedding.expand(batch_size, -1, -1)
        x = x[:, None, ...]
        x = x.expand(-1, self.n_modules, -1)

        x = torch.cat((module_embedding, x), dim=-1)
        x = einops.rearrange(x, 'batch n_modules in_dim -> (batch n_modules) in_dim')

        if self.layer_embed_dim > 0:
            # as_tensor avoids the warning + copy torch.tensor() emits when
            # layer_idx is already a tensor (callers pass one).
            layer_embedding = self.layer_embedding(torch.as_tensor(layer_idx, device=x.device))
            layer_embedding = layer_embedding.expand(batch_size, self.n_modules, -1)
            layer_embedding = einops.rearrange(layer_embedding, 'batch n_modules in_dim -> (batch n_modules) in_dim')

            x = torch.cat((layer_embedding, x), dim=-1)

        assert x.shape == (batch_size*self.n_modules, self.mixture.weight.data.shape[1]), 'Wrong at hypernetMLP.forward.x'

        h = self.ln_1(x)
        # Fixed: the normalised activations were previously discarded —
        # self.mixture(x) bypassed ln_1. Feed ln_1's output through, matching
        # the ln_2 -> c_fc stage below.
        h = self.mixture(h)
        h = self.act(h)

        h = self.ln_2(h)
        h = self.c_fc(h)
        h = self.act(h)

        h = self.c_proj(h)

        # (b*n_modules, rank*rank) -> (b, n_modules, rank, rank)
        h = einops.rearrange(h, '(batch n_modules) (rank r) -> batch n_modules rank r',
                             batch=batch_size, n_modules=self.n_modules,
                             rank=self.rank, r=self.rank)
        h = h.to(dtype_in)
        return h
|
|
def test_hypernet():
    """Smoke-test HyperNetXSexp against the default MainConfig.

    Checks that a random batch produces output of shape
    (batch, n_modules, rank, rank) and that the zero-initialised output head
    yields an all-zero B slice.
    """
    main_cfg = MainConfig()
    print(main_cfg)
    hf_model_cfg = AutoConfig.from_pretrained(main_cfg.model.base_model_name)
    print(hf_model_cfg)

    print("--- Starting HyperNetMLP Test ---")

    in_features = hf_model_cfg.hidden_size
    reduced_dim = 128
    out_features = 256
    batch_size = 27
    rank = 30
    outW = [768, 2 * 768]
    n_mlp = 2

    input_tensor = torch.randn(batch_size, main_cfg.hyperxs.n_cross_attn_tokens, in_features)

    model = HyperNetXSexp(main_cfg.hyperxs, hf_model_cfg)
    count_parameters(model)

    output = model(input_tensor, layer_idx=torch.tensor(1, dtype=torch.long))
    print('output shape', output.shape)

    B = output[:, 1, :, :768]
    print('input shape', input_tensor.shape)
    print('output shape and sum of B', output.shape, output.sum(), B.sum())

    shape_ok = output.shape == (batch_size, n_mlp, rank, rank)
    if shape_ok and B.sum().item() == 0:
        print("\n--- HyperNetMLP Test Passed Successfully! ✅ ---")
|
|
def count_parameters(model: nn.Module):
    """Print a per-parameter table of *model*'s trainable parameters and return the total.

    Parameters shared between submodules (same tensor id) are listed and counted
    only once; frozen parameters (``requires_grad=False``) are skipped entirely.
    """
    print(f'Counting params in {model.__class__.__name__}')
    total_params = 0
    counted_param_ids = set()

    print(f"{'Parameter Name':^60} | {'Shape':^20} | {'Num Params':^20}")
    print("-" * 110)

    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for name, parameter in trainable:
        # Deduplicate tied/shared weights by object identity.
        param_id = id(parameter)
        if param_id in counted_param_ids:
            print(f"Skipping shared parameter: {name}")
            continue
        counted_param_ids.add(param_id)

        shape = list(parameter.shape)
        num_params = parameter.numel()
        print(f"{name:<60} | {str(shape):<25} | {num_params:,}")
        total_params += num_params

    print(f"Model: {model.__class__.__name__} Total Trainable Params: {total_params:,}")
    return total_params
|
|
if __name__ == "__main__":
    # Manual smoke-test entry point for this module.
    print("Hello world from iba_lora")

    main_cfg = MainConfig()
    # Loaded for its side effects (config fetch); test_hypernet() builds its own.
    hf_model_cfg = AutoConfig.from_pretrained(main_cfg.model.base_model_name)

    print('-' * 50)
    test_hypernet()