ethanlshen committed on
Commit 8e6ca38
1 Parent(s): 02f45d3

Changed parallel to nn

Files changed (1)
  1. superposed/llama/superposed_model.py +20 -28
superposed/llama/superposed_model.py CHANGED
@@ -199,39 +199,31 @@ class Attention(nn.Module):
         """
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         self.n_local_heads = args.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
 
-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wo = RowParallelLinear(
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
-            bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
+            bias=False
        )
 
         self.cache_k = torch.zeros(
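Note: with `model_parallel_size` hard-coded to 1, a fairscale `ColumnParallelLinear`/`RowParallelLinear` already holds its full `(out_features, in_features)` weight on a single device, so a plain `nn.Linear` of the same shape is a drop-in replacement for these projections. A minimal sketch of the single-device projection, using illustrative sizes rather than values from the model's args:

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real values come from the model's args.
    dim, n_heads = 4096, 32
    head_dim = dim // n_heads

    # Single-device stand-in for ColumnParallelLinear(dim, n_heads * head_dim, ...).
    wq = nn.Linear(dim, n_heads * head_dim, bias=False)

    x = torch.randn(1, 8, dim)   # (batch, seq_len, dim)
    print(wq(x).shape)           # torch.Size([1, 8, 4096])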
@@ -336,14 +328,14 @@ class FeedForward(nn.Module):
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
 
     def forward(self, x):
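For context, these three projections feed the LLaMA-style SwiGLU feed-forward, `w2(silu(w1(x)) * w3(x))`; only the layer type changes in this hunk, not the forward pass. A small sketch with illustrative sizes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dim, hidden_dim = 512, 1536   # illustrative; the real hidden_dim comes from the rounding above
    w1 = nn.Linear(dim, hidden_dim, bias=False)
    w2 = nn.Linear(hidden_dim, dim, bias=False)
    w3 = nn.Linear(dim, hidden_dim, bias=False)

    x = torch.randn(2, 16, dim)
    out = w2(F.silu(w1(x)) * w3(x))   # SwiGLU: silu-gated w1 branch times w3 branch, then w2
    print(out.shape)                  # torch.Size([2, 16, 512])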
@@ -435,12 +427,12 @@ class SuperposedTransformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )
 
-        self.tok_mixing_embeddings = ColumnParallelLinear(
-            params.vocab_size, params.dim, bias=False, init_method=lambda x: x
+        self.tok_mixing_embeddings = nn.Linear(
+            params.vocab_size, params.dim, bias=False
         ) # dims here are formality (what matters is below)
         self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)
 
@@ -449,8 +441,8 @@ class SuperposedTransformer(nn.Module):
             self.layers.append(MixedTransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False
        )
 
         self.freqs_cis = precompute_freqs_cis(
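A note on the `SuperposedTransformer` hunks above: as the in-code comment says, the declared dims of `tok_mixing_embeddings` are a formality, because its weight is immediately replaced by the transposed embedding table. `nn.Linear(params.vocab_size, params.dim)` stores its weight as `(dim, vocab_size)`, which is exactly the transpose of `nn.Embedding(vocab_size, dim).weight`, so the tied layer maps a distribution over tokens to a probability-weighted mix of token embeddings, while `self.output` remains the usual unembedding back to vocabulary logits. A toy-sized sketch of the tying (sizes and variable names are illustrative, not from the repo):

    import torch
    import torch.nn as nn

    vocab_size, dim = 1000, 64   # toy sizes for illustration
    tok_embeddings = nn.Embedding(vocab_size, dim)
    tok_mixing = nn.Linear(vocab_size, dim, bias=False)        # weight shape: (dim, vocab_size)
    tok_mixing.weight = nn.Parameter(tok_embeddings.weight.T)  # tie to the embedding table

    probs = torch.softmax(torch.randn(1, vocab_size), dim=-1)  # a distribution over the vocab
    mixed = tok_mixing(probs)                                  # (1, dim): weighted mix of embeddings
    print(mixed.shape)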
 