import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaHypernetwork(nn.Module):
    def __init__(self, config):
        super().__init__()

        vocab_size = config["vocab_size"]
        hidden_dim = config["hidden_dim"]
        state_dim = config["state_dim"]
        expand = config["expand"]
        num_llm_layers = config["num_llm_layers"]
        lora_rank = config["lora_rank"]
        q_proj_dim = config["q_proj_dim"]
        v_proj_dim = config["v_proj_dim"]

        self.hidden_dim = hidden_dim
        self.num_llm_layers = num_llm_layers
        self.lora_rank = lora_rank
        self.q_proj_dim = q_proj_dim
        self.v_proj_dim = v_proj_dim

        # Shared Mamba encoder over token embeddings.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.mamba = Mamba(d_model=hidden_dim, d_state=state_dim, d_conv=4, expand=expand)

        # Pooled text features are [mean ; max], i.e. 2 * hidden_dim wide.
        self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.combine = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # One head per LLM layer, each emitting the flattened LoRA A/B factors
        # for that layer's q_proj and v_proj adapters.
        self.delta_heads = nn.ModuleList([
            nn.ModuleDict({
                "q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
                "q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_A": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
                "v_proj_B": nn.Linear(hidden_dim, lora_rank * v_proj_dim),
            })
            for _ in range(num_llm_layers)
        ])

    def encode_text(self, input_ids, attention_mask):
        emb = self.embedding(input_ids)
        mamba_out = self.mamba(emb)

        # Mean pooling over non-padding positions.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        masked_out = mamba_out * mask_expanded
        sum_out = masked_out.sum(dim=1)
        count = mask_expanded.sum(dim=1)
        mean_pooled = sum_out / (count + 1e-8)

        # Max pooling, with padding positions excluded via -inf.
        masked_out_for_max = masked_out.clone()
        masked_out_for_max[attention_mask == 0] = float("-inf")
        max_pooled = masked_out_for_max.max(dim=1).values

        # Concatenated mean- and max-pooled features: (batch, 2 * hidden_dim).
        pooled = torch.cat([mean_pooled, max_pooled], dim=-1)
        return pooled

    def forward(self, persona_ids, persona_mask, history_ids, history_mask):
        # Encode persona and dialogue history separately, then fuse them.
        persona_feat = self.encode_text(persona_ids, persona_mask)
        persona_feat = self.persona_proj(persona_feat)
        history_feat = self.encode_text(history_ids, history_mask)
        history_feat = self.history_proj(history_feat)
        combined = torch.cat([persona_feat, history_feat], dim=-1)
        combined = self.combine(combined)

        # One dict of flattened LoRA deltas per LLM layer.
        all_deltas = []
        for head in self.delta_heads:
            layer_deltas = {name: head[name](combined) for name in head}
            all_deltas.append(layer_deltas)
        return all_deltas
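

# Usage sketch (an illustration, not part of the module above): build the
# hypernetwork from a hypothetical config and run a forward pass on dummy
# persona/history token ids. All config values here are made up for the
# example; mamba_ssm's fused selective-scan kernels generally expect a CUDA
# device, so this is best run on a GPU.
if __name__ == "__main__":
    config = {
        "vocab_size": 32000,
        "hidden_dim": 256,
        "state_dim": 16,
        "expand": 2,
        "num_llm_layers": 4,
        "lora_rank": 8,
        "q_proj_dim": 1024,
        "v_proj_dim": 1024,
    }
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MambaHypernetwork(config).to(device)

    batch, persona_len, history_len = 2, 32, 64
    persona_ids = torch.randint(0, config["vocab_size"], (batch, persona_len), device=device)
    persona_mask = torch.ones(batch, persona_len, dtype=torch.long, device=device)
    history_ids = torch.randint(0, config["vocab_size"], (batch, history_len), device=device)
    history_mask = torch.ones(batch, history_len, dtype=torch.long, device=device)

    deltas = model(persona_ids, persona_mask, history_ids, history_mask)
    # Expect one dict per LLM layer; each entry is a flattened LoRA factor,
    # e.g. deltas[0]["q_proj_A"] has shape (batch, q_proj_dim * lora_rank).
    print(len(deltas), deltas[0]["q_proj_A"].shape)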