ethanlshen committed
Commit 8e6ca38 • Parent(s): 02f45d3
Changed parallel to nn
superposed/llama/superposed_model.py  CHANGED
@@ -199,39 +199,31 @@ class Attention(nn.Module):
         """
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         self.n_local_heads = args.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
 
-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wo = RowParallelLinear(
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
-            bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
+            bias=False
         )
 
         self.cache_k = torch.zeros(
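For reference, a minimal sketch of the new attention projections (not part of the commit; dim, n_heads and n_kv_heads are hypothetical values standing in for the repo's ModelArgs fields). With model_parallel_size hard-coded to 1, n_local_heads equals n_heads, so each plain nn.Linear holds the full, unsharded weight:

# Hypothetical sizes; the repository reads these from its ModelArgs.
import torch
import torch.nn as nn

dim, n_heads, n_kv_heads = 4096, 32, 8
head_dim = dim // n_heads                         # 128

wq = nn.Linear(dim, n_heads * head_dim, bias=False)
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)

x = torch.randn(2, 16, dim)                       # (batch, seq_len, dim)
q, k, v = wq(x), wk(x), wv(x)
print(q.shape, k.shape, v.shape)                  # feature dims 4096, 1024, 1024
print(wo(q).shape)                                # back to (2, 16, 4096)

With a single device there is nothing to shard, so the gather_output, input_is_parallel and init_method arguments of the fairscale layers have no nn.Linear counterpart and are simply dropped.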
@@ -336,14 +328,14 @@ class FeedForward(nn.Module):
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False
        )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
 
     def forward(self, x):
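The three projections above are the usual Llama feed-forward; the forward body sits outside this hunk, so the SwiGLU combination below is an assumption based on the reference Llama implementation rather than code taken from this commit:

# Sketch of how w1/w2/w3 are typically combined (SwiGLU); sizes are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden_dim = 4096, 11008

w1 = nn.Linear(dim, hidden_dim, bias=False)       # gate projection
w2 = nn.Linear(hidden_dim, dim, bias=False)       # down projection
w3 = nn.Linear(dim, hidden_dim, bias=False)       # up projection

x = torch.randn(2, 16, dim)
out = w2(F.silu(w1(x)) * w3(x))                   # silu-gated hidden state, projected back to dim
print(out.shape)                                  # torch.Size([2, 16, 4096])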
@@ -435,12 +427,12 @@ class SuperposedTransformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )
 
-        self.tok_mixing_embeddings = ColumnParallelLinear(
-            params.vocab_size, params.dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.tok_mixing_embeddings = nn.Linear(
+            params.vocab_size, params.dim, bias=False
         ) # dims here are formality (what matters is below)
         self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)
 
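The line to note in this hunk is the weight tie just below the changed calls: the nn.Linear is declared with (vocab_size, dim) only as a placeholder (the in-line comment calls the dims a formality), because its weight is immediately replaced by the transposed embedding table. A small sketch with hypothetical, tiny sizes shows what the tied layer then computes: a mixture vector over the vocabulary maps to the corresponding weighted sum of token embeddings.

# Sketch of the tok_embeddings / tok_mixing_embeddings weight tie; vocab_size
# and dim are hypothetical, tiny values chosen only for illustration.
import torch
import torch.nn as nn

vocab_size, dim = 10, 4
tok_embeddings = nn.Embedding(vocab_size, dim)
tok_mixing_embeddings = nn.Linear(vocab_size, dim, bias=False)
tok_mixing_embeddings.weight = nn.Parameter(tok_embeddings.weight.T)  # (dim, vocab_size), as in the commit

mix = torch.zeros(vocab_size)
mix[2], mix[7] = 0.5, 0.5                         # superpose tokens 2 and 7
expected = 0.5 * tok_embeddings.weight[2] + 0.5 * tok_embeddings.weight[7]
print(torch.allclose(tok_mixing_embeddings(mix), expected))  # True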
@@ -449,8 +441,8 @@ class SuperposedTransformer(nn.Module):
             self.layers.append(MixedTransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False
         )
 
         self.freqs_cis = precompute_freqs_cis(
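As with the other projections, the unembedding head becomes a plain nn.Linear. A short sketch (hypothetical sizes, not the repo's params) of what it produces: vocabulary logits for every position of the normalized hidden states.

# Sketch of the output head; dim and vocab_size are hypothetical values.
import torch
import torch.nn as nn

dim, vocab_size = 4096, 32000
output = nn.Linear(dim, vocab_size, bias=False)

h = torch.randn(2, 16, dim)                       # hidden states after the final RMSNorm
logits = output(h)
print(logits.shape)                               # torch.Size([2, 16, 32000])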