"""Phi-3 causal LM with a vector-memory self-correction layer.

``Phi3WithVectorMemoryForCausalLM`` swaps one ``nn.Linear`` inside a stock
``Phi3ForCausalLM`` (by default ``model.layers.15.mlp.gate_up_proj``) for a
``GCVectorMemoryLayer``. That layer projects a local view (its own input) and a
global view (the token embeddings of the current forward call) into a small
memory space, compresses them with ``VectorMemoryHead``, and uses the result to
gate and shift the wrapped layer's output.
"""

import logging
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM

logger = logging.getLogger(__name__)

# Defaults used by Phi3WithVectorMemoryForCausalLM when the config does not
# provide the corresponding optional ``vector_memory_*`` attribute.
DEFAULT_TARGET_LAYER_PATH = "model.layers.15.mlp.gate_up_proj"
DEFAULT_MEMORY_DIM = 64
DEFAULT_NUM_MEMORY_SLOTS = 8
DEFAULT_MEMORY_NUM_HEADS = 4


# --- BUILDING BLOCK 1: VectorMemoryHead ---
class VectorMemoryHead(nn.Module):
    """Compress a short sequence of vectors into a fixed number of memory slots.

    The input is self-attended by a one-layer Transformer encoder. A set of
    learned queries (one per memory slot) then cross-attends to the encoded
    vectors to produce the compressed memory. A decoder path cross-attends from
    the encoded vectors back to the compressed memory to reconstruct them.

    All attention modules use ``batch_first=True``, and every sub-module is
    created directly on ``device``/``dtype`` so the head matches the layer it is
    attached to.

    Args:
        hidden_dim: Width of the input vectors and of every memory slot.
        num_memory_slots: Number of learned query vectors (compressed slots).
        num_heads: Attention heads for the encoder and both cross-attentions.
        ff_dim: Hidden width of the encoder feed-forward and the decoder FFN.
        device: Device for all parameters.
        dtype: Dtype for all parameters.
    """

    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int, device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True,
            device=device,
            dtype=dtype,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # Learned queries, shape (1, num_memory_slots, hidden_dim); expanded over
        # the batch in forward().
        self.memory_queries = nn.Parameter(
            torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype)
        )
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1,
            batch_first=True, device=device, dtype=dtype,
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)

        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1,
            batch_first=True, device=device, dtype=dtype,
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype),
        )

    def forward(self, memory_input_sequence: torch.Tensor):
        """Encode, compress, and reconstruct ``memory_input_sequence``.

        Args:
            memory_input_sequence: ``(batch, seq, hidden_dim)`` tensor.

        Returns:
            Tuple ``(compressed_memory, reconstructed_vectors)`` with shapes
            ``(batch, num_memory_slots, hidden_dim)`` and
            ``(batch, seq, hidden_dim)``.
        """
        batch_size = memory_input_sequence.shape[0]
        encoded_vectors = self.encoder(memory_input_sequence)

        # Compression: learned slot queries attend over the encoded inputs;
        # residual connection back to the queries, then LayerNorm.
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries)

        # Reconstruction: each encoded input attends over the compressed slots;
        # residual + LayerNorm, then a position-wise FFN (no residual after it).
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=compressed_memory, value=compressed_memory)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors


# --- BUILDING BLOCK 2: The Custom Layer (with iterative self-correction) ---
class GCVectorMemoryLayer(nn.Module):
    """Wrap an ``nn.Linear`` and correct its output with a memory-derived gate/shift.

    For every token, the layer input (local view) and the matching entry of
    ``global_state_storage['embeds']`` (global view) are projected to
    ``memory_dim``. They are stacked as a 2-vector sequence and passed through
    ``VectorMemoryHead``. The mean of the compressed slots is mapped to a
    ``gate`` and a ``value`` of width ``out_features``. The wrapped layer's output
    ``y`` is then updated ``num_correction_passes`` times as
    ``y = y * sigmoid(gate) + value``.

    Args:
        original_layer: The ``nn.Linear`` being wrapped; its weight device/dtype
            are used for all new sub-modules.
        global_input_dim: Last-dimension size of the tensors stored under
            ``global_state_storage['embeds']``.
        memory_dim: Width of the memory space.
        num_memory_slots: Number of compressed memory slots.
        memory_num_heads: Attention heads inside the memory head.
        global_state_storage: Dict shared with the owning model; the layer reads
            the ``'embeds'`` entry and never writes to it.
    """

    def __init__(self, original_layer: nn.Linear, global_input_dim: int, memory_dim: int,
                 num_memory_slots: int, memory_num_heads: int, global_state_storage: Dict[str, torch.Tensor]):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer
        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim,
            num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads,
            ff_dim=memory_dim * 2,
            device=device,
            dtype=dtype,
        )
        # Produces [gate | value], each of width out_features.
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        # Number of times the gate/shift is applied. May be changed at inference
        # time; a value < 1 disables the correction entirely. Default 1.
        self.num_correction_passes: int = 1

        # Populated only in training mode, presumably for auxiliary losses
        # computed elsewhere — NOTE(review): confirm against the training loop.
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer, then the memory-driven correction.

        Args:
            x: ``(batch, seq, in_features)`` input.

        Returns:
            The corrected activation. If no ``'embeds'`` entry is available or
            ``num_correction_passes < 1``, this is the plain linear output.
        """
        base_output = self.linear(x)

        global_embeds = self.global_state_storage.get('embeds')
        if global_embeds is None or self.num_correction_passes < 1:
            return base_output

        # Keep only the trailing positions that line up with this call's tokens,
        # and co-locate with x in case the embedding lives on another device.
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        global_embeds = global_embeds.to(device=x.device)

        batch_size, seq_len, _ = x.shape

        # In training, force graph construction for the correction path even if
        # the caller runs under no_grad (reentrant gradient checkpointing runs its
        # first forward that way), so the stored tensors stay differentiable.
        # Outside training, respect the ambient grad mode instead of building an
        # unused autograd graph during eval/generation.
        with torch.set_grad_enabled(self.training or torch.is_grad_enabled()):
            # 1. Compute the correction signal once.
            proj_local = self.local_state_proj(x)
            proj_global = self.global_state_proj(global_embeds)
            # (batch, seq, 2, memory_dim) -> one 2-vector sequence per token.
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(batch_size * seq_len, 2, self.memory_dim)

            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)
            aggregated_thought = compressed_mem_flat.mean(dim=1).view(batch_size, seq_len, self.memory_dim)

            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)

            # 2. Apply it iteratively. The scale/shift are loop-invariant, so cast
            # them to the linear output's dtype (which can differ from x's under
            # autocast) once, outside the loop.
            scale = torch.sigmoid(gate.to(base_output.dtype))
            shift = value.to(base_output.dtype)
            corrected_activation = base_output
            for _ in range(self.num_correction_passes):
                corrected_activation = corrected_activation * scale + shift

        if self.training:
            self.last_corrected_activation = corrected_activation
            self.last_additive_correction = value  # the core additive signal
            self.last_memory_input = memory_input_flat
            self.last_reconstructed_from_memory = recon_flat

        return corrected_activation


# --- BUILDING BLOCK 3: The Full Custom Model ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    """``Phi3ForCausalLM`` with one linear layer replaced by ``GCVectorMemoryLayer``.

    Optional config attributes (defaults in parentheses):
        ``vector_memory_target_layer`` (``DEFAULT_TARGET_LAYER_PATH``),
        ``vector_memory_dim`` (64), ``vector_memory_num_slots`` (8),
        ``vector_memory_num_heads`` (4).

    The global state fed to the custom layer is the token-embedding output of
    the current call. It is taken from ``model.embed_tokens`` or, when the
    caller passes ``inputs_embeds``, from those. It is reset at the start of
    every ``self.model`` forward, so a stale batch is never reused.
    """

    def __init__(self, config: Phi3Config):
        super().__init__(config)
        self.global_state_storage: Dict[str, torch.Tensor] = {}
        self.target_layer_path = getattr(config, "vector_memory_target_layer", DEFAULT_TARGET_LAYER_PATH)

        # Hooks close over the dict rather than over `self`.
        storage = self.global_state_storage

        def _reset_global_state(module, args, kwargs):
            # Drop embeddings from any previous call. If the caller supplies
            # inputs_embeds, embed_tokens will not run, so capture them here.
            storage.pop('embeds', None)
            inputs_embeds = kwargs.get('inputs_embeds')
            if inputs_embeds is not None:
                storage['embeds'] = inputs_embeds.detach()

        def _capture_token_embeddings(module, inputs, output):
            storage['embeds'] = output.detach()

        self.model.register_forward_pre_hook(_reset_global_state, with_kwargs=True)
        self.model.embed_tokens.register_forward_hook(_capture_token_embeddings)

        try:
            original_layer = self.get_submodule(self.target_layer_path)
        except AttributeError:
            logger.warning("Could not find target layer '%s'. Model remains unmodified.", self.target_layer_path)
            return

        if not isinstance(original_layer, nn.Linear):
            logger.warning(
                "Target layer '%s' is %s, not nn.Linear. Model remains unmodified.",
                self.target_layer_path, type(original_layer).__name__,
            )
            return

        custom_layer = GCVectorMemoryLayer(
            original_layer=original_layer,
            global_input_dim=config.hidden_size,
            memory_dim=getattr(config, "vector_memory_dim", DEFAULT_MEMORY_DIM),
            num_memory_slots=getattr(config, "vector_memory_num_slots", DEFAULT_NUM_MEMORY_SLOTS),
            memory_num_heads=getattr(config, "vector_memory_num_heads", DEFAULT_MEMORY_NUM_HEADS),
            global_state_storage=self.global_state_storage,
        )
        parent_path, _, child_name = self.target_layer_path.rpartition(".")
        setattr(self.get_submodule(parent_path), child_name, custom_layer)
        logger.info("Successfully replaced '%s' with GCVectorMemoryLayer.", self.target_layer_path)