mamba-hypernetwork-personalization / modeling_mamba_hypernetwork.py
phammminhhieu's picture
Add model code
8faef29 verified
import torch
import torch.nn as nn
from mamba_ssm import Mamba
class MambaHypernetwork(nn.Module):
    """Hypernetwork that maps persona and history text to per-layer LoRA deltas.

    A shared token embedding + Mamba encoder turns each input text into a
    fixed-size feature (mean- and max-pooling over unmasked positions,
    concatenated). The persona and history features are projected separately,
    concatenated, fused by a small MLP, and fed to one set of linear heads per
    target LLM layer.

    Args:
        config: Mapping with the keys ``vocab_size``, ``hidden_dim``,
            ``state_dim``, ``expand``, ``num_llm_layers``, ``lora_rank``,
            ``q_proj_dim`` and ``v_proj_dim``.
    """

    def __init__(self, config):
        super().__init__()
        vocab_size = config["vocab_size"]
        hidden_dim = config["hidden_dim"]
        state_dim = config["state_dim"]
        expand = config["expand"]
        num_llm_layers = config["num_llm_layers"]
        lora_rank = config["lora_rank"]
        q_proj_dim = config["q_proj_dim"]
        v_proj_dim = config["v_proj_dim"]

        self.hidden_dim = hidden_dim
        self.num_llm_layers = num_llm_layers
        self.lora_rank = lora_rank
        self.q_proj_dim = q_proj_dim
        self.v_proj_dim = v_proj_dim

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # d_conv is fixed at 4 here rather than read from config.
        self.mamba = Mamba(d_model=hidden_dim, d_state=state_dim, d_conv=4, expand=expand)

        # Inputs are 2 * hidden_dim wide: encode_text concatenates the
        # mean-pooled and max-pooled features.
        self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)

        # Fuses the projected persona and history features into one vector.
        self.combine = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # One head set per LLM layer. Each head emits a flat vector; the
        # consumer presumably reshapes it into a LoRA factor matrix — confirm.
        # NOTE(review): v_proj_A is sized with q_proj_dim, not v_proj_dim.
        # Presumably q_proj_dim doubles as the LLM hidden size (the input width
        # of v_proj), making A (rank x hidden) and B (v_proj_dim x rank) —
        # confirm against the code that applies these deltas.
        self.delta_heads = nn.ModuleList([
            nn.ModuleDict({
                "q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
                "q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_A": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_B": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
            })
            for _ in range(num_llm_layers)
        ])

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token sequences into pooled feature vectors.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len).
            attention_mask: Same shape as ``input_ids``; zero marks padding,
                one marks a real token (assumed binary).

        Returns:
            Tensor whose last dimension is ``2 * hidden_dim``: the mean-pooled
            features over unmasked positions concatenated with the max-pooled
            features over unmasked positions. Rows with no unmasked position
            pool to all zeros instead of producing -inf/NaN.
        """
        emb = self.embedding(input_ids)
        # NOTE(review): Mamba is not given the mask, so padding tokens still
        # pass through the recurrence. Presumably inputs are right-padded so
        # real tokens are unaffected — confirm.
        mamba_out = self.mamba(emb)

        # Cast to the encoder's dtype so half/bfloat16 runs are not silently
        # promoted to float32 (which would then mismatch the projection layers).
        mask_expanded = attention_mask.unsqueeze(-1).to(mamba_out.dtype)
        masked_out = mamba_out * mask_expanded
        count = mask_expanded.sum(dim=1)

        # clamp rather than "+ 1e-8": that epsilon underflows to 0 in float16,
        # turning empty rows into 0 / 0 = NaN. For a binary mask, non-empty
        # rows have count >= 1, so the result is unchanged for them.
        mean_pooled = masked_out.sum(dim=1) / count.clamp(min=1.0)

        is_pad = (attention_mask == 0).unsqueeze(-1)
        max_pooled = masked_out.masked_fill(is_pad, float("-inf")).max(dim=1).values
        # An all-padding row max-pools to -inf, which would poison every
        # downstream layer; give such rows a neutral zero feature instead.
        max_pooled = max_pooled.masked_fill(count == 0, 0.0)

        return torch.cat([mean_pooled, max_pooled], dim=-1)

    def forward(self, persona_ids, persona_mask, history_ids, history_mask):
        """Generate LoRA delta vectors for every target LLM layer.

        Args:
            persona_ids: Token ids of the persona text.
            persona_mask: Padding mask for ``persona_ids`` (see ``encode_text``).
            history_ids: Token ids of the history text.
            history_mask: Padding mask for ``history_ids``.

        Returns:
            List with one dict per entry of ``self.delta_heads``
            (``num_llm_layers`` entries). Each dict maps a head name
            (``"q_proj_A"``, ``"q_proj_B"``, ``"v_proj_A"``, ``"v_proj_B"``)
            to that head's output, whose last dimension is the head's output
            width.
        """
        persona_feat = self.persona_proj(self.encode_text(persona_ids, persona_mask))
        history_feat = self.history_proj(self.encode_text(history_ids, history_mask))
        combined = self.combine(torch.cat([persona_feat, history_feat], dim=-1))
        return [
            {name: head_layer(combined) for name, head_layer in head.items()}
            for head in self.delta_heads
        ]