| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
class RecurrentReasoningBlock(nn.Module):
    """
    System 2 Thinking Module.

    Refines a latent representation through N steps of gated recurrence.
    Each step computes a candidate update from the pre-normed state and
    blends it in through a learned sigmoid gate:

        z_{t+1} = z_t + sigmoid(W_g * LayerNorm(z_t)) * MLP(LayerNorm(z_t))

    (Note: the update is gated element-wise — a plain residual
    ``z + MLP(LayerNorm(z))`` is the special case where the gate saturates
    at 1.)

    Args:
        d_model: Width of the latent representation.
        thinking_steps: Number of recurrent refinement iterations
            (``0`` returns the input unchanged).
        dropout: Dropout probability applied to the MLP output.
    """

    def __init__(self, d_model, thinking_steps=3, dropout=0.1):
        super().__init__()
        self.d_model = d_model  # stored for introspection; not read in forward()
        self.thinking_steps = thinking_steps

        # Candidate-update network: standard transformer-style 4x expansion MLP.
        # Weights are shared across all thinking steps (true recurrence).
        self.think_mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

        # Pre-norm: normalize the running state before both MLP and gate.
        self.norm = nn.LayerNorm(d_model)

        # Per-dimension gate deciding how much of each candidate update to accept.
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        Iteratively refine the latent state.

        Args:
            x: ``(batch, seq_len, d_model)`` initial latent (System 1 output).
               Any leading shape works as long as the last dim is ``d_model``.
        Returns:
            Tensor of the same shape: the refined latent (System 2 output).
        """
        current_thought = x

        for _ in range(self.thinking_steps):
            # Pre-norm the running state once; reused by both MLP and gate.
            normed = self.norm(current_thought)

            # Candidate refinement of the thought.
            update = self.think_mlp(normed)

            # Element-wise gate in (0, 1): how much of the update to apply.
            g = torch.sigmoid(self.gate(normed))

            # Gated residual step keeps the update additive and stable.
            current_thought = current_thought + (g * update)

        return current_thought
| |
|