ml-ryanlee committed on
Commit 4e14e99 · verified · 1 Parent(s): 7d5a033

Add trust_remote_code modeling file

Files changed (1)
  1. modeling_loop_lm.py +990 -0
modeling_loop_lm.py ADDED
@@ -0,0 +1,990 @@
"""Self-contained modeling file for trust_remote_code use.

This file merges mup_models.py and hf_wrapper.py into a single module with no
imports from looped_scaling.*. It is intended to be placed alongside a
config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that
HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it
without requiring the looped_scaling package to be installed.

Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer),
"moe" (MoETransformer), "looped-moe" (LoopedMoETransformer).
"""

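# Usage sketch: with this file and a config.json whose ``model_type`` is "loop-lm"
# and whose ``auto_map`` points at ``modeling_loop_lm.LoopLMForCausalLM``, loading
# would look roughly like the following (the repo id below is a placeholder):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(
#       "your-org/your-loop-lm-checkpoint", trust_remote_code=True
#   )
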
import torch
import math
import sys
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Callable, Iterable
from einops import rearrange, einsum, reduce, repeat
from typing import IO, Any, BinaryIO, Optional
from torch import Tensor
from collections import Counter, defaultdict
from torch.nn.functional import scaled_dot_product_attention as sdpa  # for flash attention
from torch.nn.functional import grouped_mm, silu
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

BASE_D_MODEL = 128
BASE_D_FF = 384

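# Note (illustrative arithmetic): under muP, the ``width_ratio`` passed to the modules
# below is the target width divided by the base width above, e.g. d_model = 1024 gives
# width_ratio = 1024 / BASE_D_MODEL = 8.0, the default in LoopLMConfig.
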
""" Standard Transformer and Components implemented with muP """


# ---------------------------------------------------------------------------
# Numerically stable softmax (inlined from looped_scaling/utils.py)
# ---------------------------------------------------------------------------

def softmax(logits: Tensor, dim: int) -> Tensor:
    logits = logits.float()
    # get max values over specified dimension
    max_values = torch.max(logits, dim=dim, keepdim=True).values

    # subtract max_values from x so max element is 0
    shifted = logits - max_values  # broadcast should work

    # get exp of shifted terms
    shifted_exps = torch.exp(shifted)

    # get sum of shifted terms
    shifted_exp_sums = torch.sum(shifted_exps, dim=dim, keepdim=True)

    # calculate product
    product = shifted_exps / shifted_exp_sums

    return product


# y = Wx (no bias terms!)
class Linear(nn.Module):
    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()

        # initialize weights matrix
        weights = torch.empty(out_features, in_features, dtype=dtype, device=device)

        # for muP, derive initial std deviation from given base model's std_deviation and width ratio
        std_scaled = std_base / math.sqrt(width_ratio)
        weights = nn.init.trunc_normal_(weights, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)

        # assign as instance variable
        self.weight = nn.Parameter(weights)

    def forward(self, x: Tensor) -> Tensor:
        # Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in"
        # on output side of einsum expression, so "... d_out" follows convention
        # to put the output dim last
        return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out")


class Embedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()

        # initialize a matrix of vocab_size x embedding_dim
        embeddings = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)

        # normalize the embeddings to spec
        embeddings = nn.init.trunc_normal_(embeddings, mean=0.0, std=1.0, a=-3, b=3)

        # save and enroll as torch param
        self.weight = nn.Parameter(embeddings)

    def forward(self, token_ids: Tensor) -> Tensor:
        # for every id, we need to pull the row vector associated
        return self.weight[token_ids]


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()

        # for muP no gain parameter on the rms
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        # upcast input to torch.float32
        in_dtype = x.dtype
        x = x.to(torch.float32)

        # calculate the RMS scalar
        # scalar for every ex. in batch, for every emb in sequence
        mean_squared_sum = (1/self.d_model)*einsum(x, x, "... seq d, ... seq d -> ... seq")
        rms = torch.sqrt(mean_squared_sum + self.eps)

        # for muP, no gain on rms norm as is normally applied.
        rms_norm = einsum(x, 1/rms, "... seq d, ... seq -> ... seq d")

        # return result to original dtype
        return rms_norm.to(in_dtype)

class PositionwiseFeedforward(nn.Module):
    # SwiGLU(x) = W2(SiLU(W1x)⊙W3x)
    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()

        # for muP, calculate the base model's standard deviation
        w_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_FF))  # same for all W because d_model+d_ff = d_ff+d_model

        # initialize parameters of SwiGLU FFN
        self.w1 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, w_std_base, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # FFN = W2*(SiLU(W1*X) dot W3X)
        silu_in = self.w1(x)
        silu_out = silu(silu_in)  # silu_in * torch.sigmoid(silu_in)
        gate = self.w3(x)
        gated_prod = silu_out * gate
        final_prod = self.w2(gated_prod)
        return final_prod


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        """
        theta: float Θ value for the RoPE
        d_k: int dimension of query and key vectors
        max_seq_len: int Maximum sequence length that will be inputted
        device: torch.device | None = None Device to store the buffer on
        """
        super().__init__()
        rotations = torch.empty(max_seq_len, d_k//2, 2, 2, device=device, dtype=dtype)

        # initialize rotation matrix
        for i in range(max_seq_len):
            for k in range(d_k//2):
                angle = i/(theta**(2*k/d_k))
                rot = Tensor([[math.cos(angle), -math.sin(angle)],
                              [math.sin(angle), math.cos(angle)]])
                rotations[i, k, :] = rot

        self.register_buffer("rotations", rotations, persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """
        self.rotations shape: (seq_dim, feature_dim, 2, 2)
        x: tensor of shape (..., seq_dim, feature_dim)
        token_positions: tensor of shape (..., seq_dim)
        """
        # get the correct rotation matrices
        # by default, 0'th dim of array_indexed is index dim, last dim of indices is feature dim
        rot = self.rotations[token_positions].to(dtype=x.dtype)  # match activation dtype (buffer is float32, activations may be bfloat16)

        # rearrange by every two elements along feature dim of input x
        x_pairs = rearrange(x, "... seq_dim (feature_dim i) -> ... seq_dim feature_dim i", i=2)

        # apply rotations to these. for each pairwise position is A@x->y : (ixj)@(j,)->(i,)
        y_pairs = einsum(rot, x_pairs, "... seq_dim feature_dim i j, ... seq_dim feature_dim j -> ... seq_dim feature_dim i")

        # reshape y_pairs back to original shape
        y = rearrange(y_pairs, "... seq_dim feature_dim i -> ... seq_dim (feature_dim i)")

        return y

def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Given key (K), query (Q), and value (V) tensors, return
    the output of your scaled dot product attention implementation.

    Args:
        let m be seq length of inputs, n be seq length of outputs
        d_k is look-up dim, d_v is value dim
        Q (Float[Tensor, "batch ... n d_k"]): Query tensor
        K (Float[Tensor, "batch ... m d_k"]): Key tensor
        V (Float[Tensor, "batch ... m d_v"]): Values tensor
        mask (Float[Tensor, " ... n m"] | None): Mask tensor
    Returns:
        Float[Tensor, " ... n d_v"]: Output of SDPA
    """

    # get the key feature dim (should be last dim of Q and K)
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]

    # calculate the weighted scores (similarity product). for muP, scale by d_k not sqrt(d_k)
    scores = einsum(Q, K, "... n d_k, ... m d_k -> ... n m") / d_k

    # apply the mask if there is one
    if mask is not None:
        bool_mask = mask.bool()  # compatible whether the input mask is bool or float
        attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(scores.dtype)
        scores = scores + attn_mask

    # calculate the attention weights
    weights = softmax(scores, dim=-1)  # the softmax should be taken over the m inputs at an i'th output pos.

    # return weights@V
    return einsum(weights, V, "... n m, ... m d_v -> ... n d_v")

class MultiheadSelfAttention(nn.Module):
    """
    Args:
        d_model (int): Dimensionality of the feedforward input and output.
        num_heads (int): Number of heads to use in multi-headed attention.
        max_seq_len (int): Maximum sequence length to pre-cache if your implementation does that.
        q_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the Q projection
        k_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the K projection
        v_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the V projection
        o_proj_weight (Float[Tensor, "d_model d_v"]): Weights for the output projection
        in_features (Float[Tensor, "... sequence_length d_in"]): Tensor to run your implementation on.

    Returns:
        Float[Tensor, " ... sequence_length d_out"]: Tensor with the output of running your optimized, batched multi-headed attention
        implementation with the given QKV projection weights and input features.
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()

        # initialize the multi-head self attention weights as 1 large matrix (which will be sliced)
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads

        # for muP, calculate standard deviation of base model
        attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))

        # for muP, initialize the Wq,Wk,Wv,Wo linear weights with width_ratio and base model's stddev
        self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)

        # # Removed for torch sdpa, uncomment if using normal code
        # if max_seq_len:
        #     causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=dtype, device=device))
        #     self.register_buffer("causal_mask", causal_mask, persistent=False)
        # else:
        #     self.register_buffer("causal_mask", None, persistent=False)

        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."

        if theta:
            d_k = d_model//num_heads
            self.rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # get Q, K, V matrices
        Q = self.q_proj(x)  # output shape is [batch seq d_model]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # # create causal mask interpreting the second to last dim as seq dim
        # if self.causal_mask is None:
        #     seq_dim = x.shape[-2]
        #     cmask = torch.tril(torch.ones(seq_dim, seq_dim, dtype=x.dtype, device=x.device))
        # else:
        #     # Slice the pre-computed mask to match actual sequence length (could be < than max_seq_len)
        #     seq_dim = x.shape[-2]
        #     cmask = self.causal_mask[:seq_dim, :seq_dim]

        # get slice size for multi-head self attention
        d_k = self.d_model // self.num_heads
        d_v = self.d_model // self.num_heads

        q_heads = rearrange(Q, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)
        k_heads = rearrange(K, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)

        # apply RoPE to q_heads and k_heads
        if self.rope:
            seq_dim = x.shape[-2]  # x is (b,s,d)
            if token_positions is None:
                token_positions = torch.arange(seq_dim, device=x.device)
                token_positions = rearrange(token_positions, "seq -> 1 seq")  # 1 seq allows broadcast across batch dim

            q_heads = self.rope(q_heads, token_positions)
            k_heads = self.rope(k_heads, token_positions)

        v_heads = rearrange(V, "... seq (heads d_v) -> ... heads seq d_v", d_v=d_v)

        # mha_heads = scaled_dot_product_attention(q_heads, k_heads, v_heads, cmask)
        mha_heads = sdpa(q_heads, k_heads, v_heads, is_causal=True, scale=1.0/d_k)
        mha = rearrange(mha_heads, "... heads seq d_v -> ... seq (heads d_v)")

        # apply o_proj_weight to the concatenated multi-head attention product
        out = self.output_proj(mha)

        return out

class PrenormBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # norm layer
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # add step
        # norm layer
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # positionwise feed forward
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)
        # add to output

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # second Tx operation, Norm + SwiGLU
        norm2_out = self.ln2(resid1_out)
        ffn_out = self.ffn(norm2_out)

        # ensure no broadcasting, elementwise addition
        assert(ffn_out.shape == resid1_out.shape)
        final_out = resid1_out + ffn_out
        return final_out

class MuTransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        if weight_tying:
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)
        self.width_ratio = width_ratio

    def forward(self, x: Tensor) -> Tensor:
        # 1. token embed step, no muP alpha_in
        x = self.token_embeddings(x)

        # 2. prenorm blocks step
        for layer in self.layers:
            x = layer(x)

        # 3. Final norm
        x = self.ln_final(x)

        # 4. unembed layer, muP implemented as scaling on init variance and lr of lm_head, not output scaling
        if self.weight_tying:
            x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        else:
            x = self.lm_head(x)

        # 5. return output, no muP alpha_out
        return x

""" Looped Language Models implemented with MuP """

class LoopedStack(nn.Module):
    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            # self.layers = nn.ModuleList([MoEPrenormBlock(d_model,num_heads,d_ff,num_experts,num_active,
            #                                              context_length,rope_theta,width_ratio,device,dtype)
            #                              for _ in range(num_layers_in_stack)])
            self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                                                context_length, rope_theta, width_ratio, device, dtype)
                                         for _ in range(num_layers_in_stack)])
        else:
            self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                                                      width_ratio, device, dtype) for _ in range(num_layers_in_stack)])
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor) -> Tensor:
        # prenorm blocks step
        if self.mixture_of_experts:
            lb_total = 0
            lz_total = 0
            # sum up load balancing and z-losses across each layer
            for layer in self.layers:
                x, lb, lz = layer(x)
                lb_total += lb
                lz_total += lz
            return x, lb_total, lz_total
        else:
            for layer in self.layers:
                x = layer(x)
            return x

class LoopedTransformer(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks

        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio

        if weight_tying:
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # token embed step
        x = self.token_embeddings(x)

        # repeated calls to stack
        for i in range(self.num_stacks):
            x = self.stack(x)

        # final norm
        x = self.ln_final(x)

        # Vocab projection or lm_head
        if self.weight_tying:
            x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
        else:
            x = self.lm_head(x)

        return x

""" Mixture-of-Experts Implementation in muP """

# Router Class
class Router(nn.Module):
    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # router is simply a linear layer. we initialize (d_in, d_out) according to my code
        std_base = math.sqrt(2/(BASE_D_MODEL+num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, std_base, device=device, dtype=dtype)  # adjusted for muP
        self.num_active = num_active

    def forward(self, x: Tensor):
        # returns scores, top_k_scores, top_k_indices
        logits = self.gate(x)  # should be shape (batch, seq, n_routers)

        # probs
        probs = softmax(logits, dim=-1)

        # get top_k
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)

        # renormalize the top scores so weighted sum of expert products can be taken
        score_sums = torch.sum(top_scores, dim=-1, keepdim=True)  # (batch, seq)
        top_scores = top_scores/score_sums

        return logits, probs, top_scores, top_experts

class MoEPrenormBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before MHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedforward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # initialize MoE FFNs as a module list
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)])  # adjusted for muP

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # input dims
        batch, seq, dim = x.shape

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. returns shape (batch,seq,k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)  # logits and probs are (batch, seq, n_routers)
        expert_mean_probs = torch.mean(probs, dim=(0, 1))  # take mean across batch and seq dims

        # apply mixture of experts
        experts_out = torch.zeros_like(norm2_out)  # copies shape, device and dtype
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0

        for expert_idx in range(self.num_experts):
            # get masks for expert selection
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1)  # if any of the k is expert, we want to transform embed
            if not embed_mask.any(): continue
            pi = expert_mean_probs[expert_idx].item()
            fi = (expert_mask.sum().item())/total_tokens_assigned  # num embeds assigned to expert in batch
            lb_sum += fi*pi

            # extract embeds and weights for activated experts
            weights = top_scores[expert_mask]  # (num_embeds)
            expert_embeds = norm2_out[embed_mask]  # (num_embeds, hidden_dim)

            # forward for the correct experts
            expert_out = self.experts[expert_idx](expert_embeds)  # Vanilla Implementation

            # map back to experts output
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out  # broadcast elementwise multiply by hidden dim

        # calculate batch's load balancing loss
        lb = self.num_experts*lb_sum

        # calculate batch's router z loss
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)

        # ensure no broadcasting, elementwise addition
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz


class GroupedMoEPrenormBlock(nn.Module):
    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype)  # (num_experts, in, out)
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before MHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedforward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # initialize MoE FFNs as a module list
        d_ff_expert = d_ff // num_active

        # expose and stack the MoE SwiGLU weights for all experts; with "experts" in the parameter name, the optimizer scales these weights by width_ratio
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        batch, seq, dim = x.shape
        total_tokens = batch * seq

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)

        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. returns shape (batch, seq, k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)

        # flatten to 2D for grouped_mm
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d')  # (total_tokens, d_model)
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)')  # (total_tokens * k,)
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)')  # (total_tokens * k,)
        flat_positions = torch.arange(total_tokens, device=x.device)  # (total_tokens)
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active)  # (total_tokens * k)

        # sort by expert
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids]  # (total_tokens * k, d_model)

        # build offs (cumulative token counts per expert)
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)  # (num_experts,)

        # grouped SwiGLU: W2(SiLU(W1 x) dot W3 x)
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs)  # (total_tokens * k, d_model)

        # weight by router scores and scatter-add back
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)

        # reshape back to (batch, seq, d_model)
        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)

        # aux losses
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')

        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')

        # residual connection
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz

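# Worked example of the offs layout used by GroupedMoEPrenormBlock.forward (illustrative
# numbers only): with num_experts = 3 and per-expert token counts [4, 0, 2] after sorting,
# ``offs = counts.cumsum(0) = [4, 4, 6]``, so the grouped matmuls above apply expert 0's
# weights to rows 0:4 of ``sorted_x``, expert 1's to the empty slice 4:4, and expert 2's
# to rows 4:6.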

# MoE Implementation
class MoETransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers
        # self.layers = nn.ModuleList([MoEPrenormBlock(d_model,num_heads,d_ff,num_experts,num_active,
        #                                              context_length,rope_theta,width_ratio,device,dtype) for _ in range(num_layers)])
        self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                                            context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # only non-tied embeddings now
        std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # collect aux losses
        lb_total = 0
        lz_total = 0

        # 1. token embed step
        x = self.token_embeddings(x)

        # 2. prenorm blocks step
        for layer in self.layers:
            x, lb, lz = layer(x)
            lb_total += lb
            lz_total += lz

        # 3. Final norm
        x = self.ln_final(x)

        # 4. Vocab projection or lm_head
        x = self.lm_head(x)

        # calculate average layer aux loss
        lb_avg = lb_total / self.num_layers
        lz_avg = lz_total / self.num_layers

        return x, lb_avg, lz_avg

class LoopedMoETransformer(nn.Module):
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float,
            device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        self.total_layers = num_stacks*num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)  # parameters for loop with MoE
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # scale lm head
        std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # collect aux losses
        lb_total = 0
        lz_total = 0

        # token embed step
        x = self.token_embeddings(x)

        # repeated calls to stack
        for i in range(self.stack_depth):
            x, lb, lz = self.stack(x)
            lb_total += lb
            lz_total += lz

        # final norm
        x = self.ln_final(x)

        # Vocab projection or lm_head
        x = self.lm_head(x)

        # calculate aux loss averages
        lb_avg = lb_total / self.total_layers
        lz_avg = lz_total / self.total_layers

        return x, lb_avg, lz_avg


# ---------------------------------------------------------------------------
# HuggingFace wrapper (from hf_wrapper.py)
# ---------------------------------------------------------------------------

class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants."""

    model_type = "loop-lm"

    def __init__(
        self,
        # which of the four architectures to use
        model_variant: str = "base",  # "base" | "looped" | "moe" | "looped-moe"
        # shared
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,  # d_model / base_d_model (128); set at training time
        # base + moe only
        num_layers: int = 16,
        # base + looped only
        weight_tying: bool = False,
        # looped + looped-moe only
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # moe + looped-moe only
        num_experts: int = 8,
        num_active: int = 2,
        # aux loss weights, used when forward() is called with labels
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_variant = model_variant
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.rope_theta = rope_theta
        self.width_ratio = width_ratio
        self.num_layers = num_layers
        self.weight_tying = weight_tying
        self.num_layers_in_stack = num_layers_in_stack
        self.num_stacks = num_stacks
        self.num_experts = num_experts
        self.num_active = num_active
        self.lb_loss_factor = lb_loss_factor
        self.lz_loss_factor = lz_loss_factor
        # lm-evaluation-harness looks for this attribute to cap sequence length
        self.max_length = context_length


class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
    - Upload/download via push_to_hub / from_pretrained
    - Run lm-evaluation-harness evals
    - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """

    config_class = LoopLMConfig
    # no checkpoint keys are expected to be missing on load
    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        self.model = self._build_inner_model(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------

    def _build_inner_model(self, config: LoopLMConfig):
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
            # device=None so weights are placed on CPU; caller uses .to(device)
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    # ------------------------------------------------------------------
    # Embedding access (required by some HF utilities)
    # ------------------------------------------------------------------

    def get_input_embeddings(self):
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        self.model.token_embeddings = value

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # causal mask is handled internally
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored; the models use a built-in causal mask
            labels: (batch, seq) token ids; if provided, returns cross-entropy loss.
                For MoE variants, aux losses (lb + lz) are added to the CE loss.
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")

        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0

        loss = None
        if labels is not None:
            ce_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            loss = ce_loss + aux

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    # ------------------------------------------------------------------
    # Generation support (no KV cache: generation is correct but slow)
    # ------------------------------------------------------------------

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}
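
# Registration sketch (illustrative; these are comments and nothing runs on import).
# When using this file without ``trust_remote_code``, the config/model pair can be
# registered with the Auto classes, e.g.:
#
#   AutoConfig.register("loop-lm", LoopLMConfig)
#   AutoModelForCausalLM.register(LoopLMConfig, LoopLMForCausalLM)
#
# For Hub loading with ``trust_remote_code=True``, the accompanying config.json would
# typically carry an entry such as:
#
#   "auto_map": {
#       "AutoConfig": "modeling_loop_lm.LoopLMConfig",
#       "AutoModelForCausalLM": "modeling_loop_lm.LoopLMForCausalLM"
#   }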