| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from otitans_core import OLoRALinear |
|
|
class OTitansMemoryGate(nn.Module):
    """
    Phase 2: The OTITANS Memory Core.
    A recurrent memory state shielded by orthogonal LoRA projections.

    A ``(hidden_size, hidden_size)`` associative-memory matrix is carried
    across forward calls as a persistent (non-trainable) buffer. At each
    timestep the module:

      1. reads from memory with a query projection,
      2. measures "surprise" — the gap between the value and what the
         memory already predicts for the key, and
      3. writes the surprise back as a rank-1 outer-product update on a
         memory decayed by ``memory_momentum``.

    The retrieved memory is blended into the residual stream through a
    learned sigmoid gate.

    NOTE(review): the docstring mentions orthogonal LoRA projections and
    the file imports ``OLoRALinear``, yet q/k/v use plain ``nn.Linear``
    and ``rank`` is unused — confirm whether the projections were meant
    to be ``OLoRALinear(hidden_size, hidden_size, rank)``.
    """

    def __init__(self, hidden_size: int, rank: int = 8, memory_momentum: float = 0.9):
        """
        Args:
            hidden_size: Width of the residual stream. The memory matrix
                is ``(hidden_size, hidden_size)``.
            rank: Reserved for the OLoRA projections (currently unused;
                kept for interface compatibility).
            memory_momentum: Decay applied to the existing memory before
                each rank-1 write; values near 1.0 retain longer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.memory_momentum = memory_momentum

        # Query/key/value projections into the associative memory space.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # Sigmoid gate deciding, per channel, how much retrieved memory
        # to mix back into the residual stream. Input is the concatenation
        # of the hidden state and the memory readout.
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, hidden_size),
            nn.Sigmoid()
        )

        # Persistent recurrent state: moves with .to(device) and is saved
        # in state_dict, but receives no gradient.
        self.register_buffer("memory_state", torch.zeros(hidden_size, hidden_size))

    def reset_memory(self):
        """Wipes the recurrent memory clean for a new session."""
        self.memory_state.zero_()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run one segment through the memory gate.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            Tensor of the same shape: ``hidden_states`` plus the gated
            memory readout.

        Side effects:
            Updates ``self.memory_state`` in place with the (detached)
            batch-averaged final memory, so state persists across calls.
        """
        batch_size, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Give every batch element its own working copy of the shared
        # memory. BUG FIX: the per-step outer-product update is (B, H, H),
        # so the working memory becomes batched after the first write; the
        # old code then copy_()'d a (B, H, H) tensor into the (H, H)
        # buffer, which raises for batch_size > 1. Expanding up front and
        # persisting the batch mean keeps B == 1 numerically identical.
        current_memory = (
            self.memory_state.unsqueeze(0).expand(batch_size, -1, -1).clone()
        )

        memory_outputs = []
        for t in range(seq_len):
            q_t = q[:, t, :]
            k_t = k[:, t, :]
            v_t = v[:, t, :]

            # Read: (B, 1, H) x (B, H, H) -> (B, H) memory retrieval.
            retrieval = torch.bmm(q_t.unsqueeze(1), current_memory).squeeze(1)
            memory_outputs.append(retrieval)

            # Surprise: how far the stored association for k_t is from v_t.
            memory_prediction = torch.bmm(k_t.unsqueeze(1), current_memory).squeeze(1)
            surprise = v_t - memory_prediction

            # Write: decayed memory plus rank-1 outer product surprise ⊗ k_t.
            update = torch.bmm(surprise.unsqueeze(2), k_t.unsqueeze(1))
            current_memory = self.memory_momentum * current_memory + update

        memory_out_tensor = torch.stack(memory_outputs, dim=1)

        # Persist across calls; detach so no stale autograd graph is kept
        # alive, and average over the batch to fit the (H, H) buffer.
        self.memory_state.copy_(current_memory.detach().mean(dim=0))

        # Gate the readout against the residual stream and mix it in.
        gate_input = torch.cat([hidden_states, memory_out_tensor], dim=-1)
        gate_value = self.gate(gate_input)

        return hidden_states + (gate_value * memory_out_tensor)
|
|