ataeff commited on
Commit
2ca7d54
·
verified ·
1 Parent(s): b4f66a6

Add model.py

Browse files
Files changed (1) hide show
  1. model.py +233 -0
model.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Resonance 200M — Content + RRPRAM dual attention transformer.
3
+ Low-rank RRPRAM (Wr = Wr_a @ Wr_b), SwiGLU MLP, RMSNorm, RoPE.
4
+ Content attention uses FlashAttention via F.scaled_dot_product_attention.
5
+
6
+ Architecture matches resonance-bpe.c (with low-rank extension).
7
+ """
8
+
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+
15
+
16
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Normalizes the last dimension by its RMS and applies a learned
    per-channel scale initialized to 1.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps  # added inside rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # 1 / RMS over the feature dimension, then rescale per channel.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
24
+
25
+
26
class ResonanceBlock(nn.Module):
    """
    Dual attention block: Content (QKV + RoPE + FlashAttn) + RRPRAM (low-rank Wr) + SwiGLU MLP.

    Content path: standard causal multi-head attention with rotary position
    embeddings, run through F.scaled_dot_product_attention.
    RRPRAM path: position-dependent attention whose logits come from a learned
    per-head low-rank map Wr = wr_a @ wr_b ([E, R] @ [R, T]); it shares V with
    the content path. A per-head sigmoid gate blends the two attention outputs
    before the shared output projection.
    """

    def __init__(self, config):
        """Build the block from a config dict with keys: n_embd, n_head,
        head_dim, rrpram_rank, context_len, ffn_dim, n_layer."""
        super().__init__()
        E = config['n_embd']
        H = config['n_head']
        D = config['head_dim']
        R = config['rrpram_rank']
        T = config['context_len']
        M = config['ffn_dim']

        self.n_head = H
        self.head_dim = D
        self.n_embd = E

        # Pre-attention norm
        self.norm1 = RMSNorm(E)

        # Content attention (MHA): Q, K, V
        self.wq = nn.Linear(E, H * D, bias=False)
        self.wk = nn.Linear(E, H * D, bias=False)
        self.wv = nn.Linear(E, H * D, bias=False)

        # RRPRAM attention (low-rank): Wr_a[H, E, R] @ Wr_b[H, R, T] = Wr[H, E, T]
        # (2/fan_in)**0.5 scaling keeps the factored product's variance bounded.
        self.wr_a = nn.Parameter(torch.randn(H, E, R) * (2.0 / E) ** 0.5)
        self.wr_b = nn.Parameter(torch.randn(H, R, T) * (2.0 / R) ** 0.5)

        # Per-head gate: sigmoid(gate) blends content vs RRPRAM
        self.gate = nn.Parameter(torch.zeros(H))  # init 0 → sigmoid(0) = 0.5 = balanced

        # Output projection
        self.wo = nn.Linear(E, E, bias=False)

        # Pre-MLP norm
        self.norm2 = RMSNorm(E)

        # SwiGLU MLP
        self.mlp_gate = nn.Linear(E, M, bias=False)
        self.mlp_up = nn.Linear(E, M, bias=False)
        self.mlp_down = nn.Linear(M, E, bias=False)

        # Init output projections with smaller std (GPT-2 convention: divide by
        # sqrt(2 * n_layer) so the residual stream's variance stays bounded
        # as depth grows — two residual additions per block).
        n_layer = config['n_layer']
        nn.init.normal_(self.wo.weight, std=0.02 / math.sqrt(2 * n_layer))
        nn.init.normal_(self.mlp_down.weight, std=0.02 / math.sqrt(2 * n_layer))

    def forward(self, x, rope_cos, rope_sin, mask):
        """
        Args:
            x: [B, T, E] residual stream.
            rope_cos, rope_sin: RoPE tables broadcastable to [B, H, T, D//2]
                (sliced to the current T by the caller).
            mask: [T, T] bool, True strictly above the diagonal (future positions).
        Returns:
            [B, T, E] updated residual stream.
        """
        B, T, E = x.shape
        H = self.n_head
        D = self.head_dim

        # Pre-norm
        xn = self.norm1(x)

        # === Content attention with RoPE + FlashAttention ===
        q = self.wq(xn).view(B, T, H, D).transpose(1, 2)  # [B, H, T, D]
        k = self.wk(xn).view(B, T, H, D).transpose(1, 2)
        v = self.wv(xn).view(B, T, H, D).transpose(1, 2)

        # Apply RoPE to Q and K (V is position-agnostic)
        q = _apply_rope(q, rope_cos, rope_sin)
        k = _apply_rope(k, rope_cos, rope_sin)

        # FlashAttention — O(T) memory instead of O(T²); causal masking is
        # handled internally, so the explicit `mask` is not needed here.
        c_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [B, H, T, D]

        # === RRPRAM attention (low-rank) ===
        # Wr = Wr_a @ Wr_b: [H, E, R] @ [H, R, T] = [H, E, T]
        # Score: xn @ Wr → [B, T, E] @ [H, E, T] → [B, H, T, T]
        xn_h = xn.unsqueeze(1).expand(-1, H, -1, -1)  # [B, H, T, E]
        # Low-rank: contract through rank R first — (xn @ Wr_a) @ Wr_b — so the
        # full [H, E, T] matrix is never materialized.
        temp = torch.einsum('bhie,her->bhir', xn_h, self.wr_a)  # [B, H, T, R]
        r_attn = torch.einsum('bhir,hrj->bhij', temp, self.wr_b[:, :, :T])  # [B, H, T, T]
        r_attn = r_attn * (D ** -0.5)
        # Causal mask applied explicitly (this path bypasses SDPA).
        r_attn = r_attn.masked_fill(mask, float('-inf'))
        r_attn = F.softmax(r_attn, dim=-1)
        r_out = r_attn @ v  # [B, H, T, D] — shared V with content

        # === Gate: blend content and RRPRAM ===
        g = torch.sigmoid(self.gate).view(1, H, 1, 1)  # [1, H, 1, 1]
        attn_out = g * c_out + (1 - g) * r_out  # [B, H, T, D]

        # Output projection + residual
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, E)
        x = x + self.wo(attn_out)

        # === SwiGLU MLP ===
        xn = self.norm2(x)
        gate = F.silu(self.mlp_gate(xn))
        up = self.mlp_up(xn)
        x = x + self.mlp_down(gate * up)

        return x
122
+
123
+
124
+ def _apply_rope(x, cos, sin):
125
+ """Apply RoPE to tensor x: [B, H, T, D]."""
126
+ x1 = x[..., ::2] # even dims
127
+ x2 = x[..., 1::2] # odd dims
128
+ out = torch.stack([
129
+ x1 * cos - x2 * sin,
130
+ x1 * sin + x2 * cos,
131
+ ], dim=-1).flatten(-2)
132
+ return out
133
+
134
+
135
class Resonance(nn.Module):
    """
    Resonance 200M: dual attention (Content + RRPRAM) transformer.

    Token embedding → n_layer ResonanceBlocks → final RMSNorm → untied LM head.
    Positions come from precomputed RoPE tables; a boolean causal mask is kept
    for the RRPRAM path (content attention uses is_causal=True inside SDPA).
    """

    def __init__(self, config):
        """Build the model from a config dict (see RESONANCE_200M for keys)."""
        super().__init__()
        self.config = config
        V = config['vocab_size']
        E = config['n_embd']
        T = config['context_len']
        D = config['head_dim']

        # Token embedding (no position table — RoPE handles it)
        self.tok_emb = nn.Embedding(V, E)
        nn.init.normal_(self.tok_emb.weight, std=0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ResonanceBlock(config) for _ in range(config['n_layer'])
        ])

        # Final norm + output head (untied from embedding)
        self.norm_f = RMSNorm(E)
        self.out_head = nn.Linear(E, V, bias=False)
        nn.init.normal_(self.out_head.weight, std=0.02)

        # Precompute RoPE angle tables for the full context length.
        freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
        t = torch.arange(T).float()
        angles = torch.outer(t, freqs)
        self.register_buffer('rope_cos', angles.cos().unsqueeze(0).unsqueeze(0))  # [1,1,T,D//2]
        self.register_buffer('rope_sin', angles.sin().unsqueeze(0).unsqueeze(0))

        # Causal mask (for RRPRAM — content uses is_causal=True in SDPA)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        self.register_buffer('causal_mask', mask)

        # Gradient checkpointing is off unless enabled explicitly
        # (previously relied on a getattr fallback in forward).
        self._grad_ckpt = False

        n_params = sum(p.numel() for p in self.parameters())
        print(f" [Resonance] {n_params:,} parameters")
        self._report_balance()

    def _report_balance(self):
        """Report parameter budget distribution (analytic count per component)."""
        cfg = self.config
        E, H, D = cfg['n_embd'], cfg['n_head'], cfg['head_dim']
        R, T, M = cfg['rrpram_rank'], cfg['context_len'], cfg['ffn_dim']
        V, L = cfg['vocab_size'], cfg['n_layer']

        emb = V * E * 2  # tok_emb + out_head (untied)
        qkv = L * (3 * E * H * D)
        rrpram = L * (H * E * R + H * R * T + H)  # wr_a + wr_b + gate
        wo = L * E * E
        mlp = L * (3 * E * M)  # gate + up + down projections
        norms = L * 2 * E + E  # per-block norms + final

        total = emb + qkv + rrpram + wo + mlp + norms
        print(f" [Resonance] Budget: emb={emb/total*100:.1f}% qkv={qkv/total*100:.1f}% "
              f"rrpram={rrpram/total*100:.1f}% wo={wo/total*100:.1f}% "
              f"mlp={mlp/total*100:.1f}% norms={norms/total*100:.1f}%")

    def set_gradient_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing for the block stack."""
        self._grad_ckpt = enable

    def forward(self, idx, targets=None):
        """
        Args:
            idx: [B, T] integer token ids; T must not exceed context_len.
            targets: optional [B, T] integer labels for next-token loss.
        Returns:
            (logits [B, T, vocab_size], loss) — loss is None without targets.
        Raises:
            ValueError: if T exceeds the precomputed context length.
        """
        B, T = idx.shape
        max_len = self.causal_mask.size(0)
        if T > max_len:
            # Fail loudly here instead of with a cryptic broadcast error
            # inside the blocks (rope/mask slices would come back short).
            raise ValueError(f"sequence length {T} exceeds context_len {max_len}")

        x = self.tok_emb(idx)

        # Slice precomputed RoPE tables / mask to the actual sequence length.
        cos = self.rope_cos[:, :, :T, :]
        sin = self.rope_sin[:, :, :T, :]
        mask = self.causal_mask[:T, :T]

        for block in self.blocks:
            if self._grad_ckpt and self.training:
                # Recompute block activations in backward to save memory.
                x = torch.utils.checkpoint.checkpoint(
                    block, x, cos, sin, mask, use_reentrant=False)
            else:
                x = block(x, cos, sin, mask)

        logits = self.out_head(self.norm_f(x))

        loss = None
        if targets is not None:
            # Flatten to [B*T, V] vs [B*T] for token-level cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
221
+
222
+
223
# === Default config: ~200M params ===
RESONANCE_200M = dict(
    n_embd=768,
    n_head=12,
    head_dim=64,        # n_embd // n_head
    n_layer=20,
    rrpram_rank=48,     # low-rank R
    context_len=2048,
    ffn_dim=2048,       # round(8*768/3, 256)
    vocab_size=16384,   # 256 + 16128 BPE merges
)
+ }