HaileyStorm committed
Commit 27f8947
Parent(s): 62e35f0

Upload chess-mamba-vs-xformer/train_bygame.py with huggingface_hub

chess-mamba-vs-xformer/train_bygame.py CHANGED
@@ -92,17 +92,22 @@ anneal_decay_iters = None # Set at init
 
 if model_type == 'mamba':
     from mamba_lm import MambaLM, MambaLMConfig
+    from mamba_ssm import MambaLMHeadModel
     model_config = MambaLMConfig(
         d_model=d_model,
-        n_layers=n_layer,
-        dt_rank=dt_rank,
-        d_state=d_state,
-        expand_factor=expand_factor,
-        bias=bias,
-        conv_bias=conv_bias,
-        pscan=pscan,
-        vocab_size=vocab_size
-    )
+        #n_layers=n_layer,
+        n_layer=n_layer,
+        ssm_cfg={
+            'dt_rank': dt_rank,
+            'd_state': d_state,
+            #'expand_factor': expand_factor,
+            'bias': bias,
+            'conv_bias': conv_bias,
+            #'pscan': pscan,
+        },
+        vocab_size=vocab_size,
+        pad_vocab_size_multiple=1
+    ).to_mamba_config()
 elif model_type == 'xformer':
     from xformer import GPTConfig, GPT
     model_config = GPTConfig(
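This hunk swaps the pure-PyTorch mamba.py configuration (n_layers, expand_factor, pscan) for one the CUDA-accelerated mamba_ssm package understands: per-block settings move into ssm_cfg, and .to_mamba_config() converts the wrapper config before it reaches MambaLMHeadModel. Below is a minimal sketch of an equivalent setup written directly against mamba_ssm's public classes; the hyperparameter values are illustrative, not the project's, and the direct-MambaConfig route is an assumption (the script goes through its local mamba_lm wrapper instead).

```python
# Sketch only: build an equivalent model directly with mamba_ssm's public classes.
# All hyperparameter values below are illustrative, not the project's settings.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(
    d_model=512,
    n_layer=16,                 # mamba_ssm uses n_layer, not mamba.py's n_layers
    vocab_size=32,
    ssm_cfg={                   # forwarded as keyword args to every Mamba block
        "dt_rank": "auto",
        "d_state": 16,
        "bias": False,
        "conv_bias": True,
    },
    pad_vocab_size_multiple=1,  # keep the LM head exactly vocab_size wide
)

# mamba_ssm's selective-scan kernels are CUDA-only, so build the model on the GPU.
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float32)
print(sum(p.numel() for p in model.parameters()), "parameters")
```

Options that only exist in mamba.py (expand_factor, pscan) have no mamba_ssm equivalent under those names, which is why the diff comments them out rather than moving them into ssm_cfg.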
@@ -152,10 +157,13 @@ train_files = glob.glob(os.path.join(data_dir, 'train*.parquet')) + \
               glob.glob(os.path.join(data_dir, 'stable*.parquet')) + \
               glob.glob(os.path.join(data_dir, 'anneal*.parquet'))
 train_datasets = []
+print("Loading dataset...")
 for f in train_files:
     dataset = pq.read_table(f).to_pandas()
     dataset = dataset[dataset['tokenized'].apply(len) >= 8]
     train_datasets.append(dataset)
+    print('.', end='', flush=True)
+print("\nLoaded.")
 #val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
 #val_data = val_data[val_data['tokenized'].apply(len) >= 8]
 truncated_games_count = 0
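The only change here is coarse progress output while the parquet shards are read and filtered to games with at least 8 tokens. As an illustration only, the same loop could be factored into a helper that reports per-file counts; the function name and message format below are not from the repo.

```python
# Sketch only: the shard-loading loop as a helper with per-file progress reporting.
import glob
import os
import pyarrow.parquet as pq

def load_game_frames(data_dir, min_tokens=8):
    files = sorted(glob.glob(os.path.join(data_dir, 'train*.parquet')) +
                   glob.glob(os.path.join(data_dir, 'stable*.parquet')) +
                   glob.glob(os.path.join(data_dir, 'anneal*.parquet')))
    frames = []
    for i, f in enumerate(files, 1):
        df = pq.read_table(f).to_pandas()
        df = df[df['tokenized'].apply(len) >= min_tokens]  # drop games too short to train on
        frames.append(df)
        print(f"[{i}/{len(files)}] {os.path.basename(f)}: {len(df)} games kept", flush=True)
    return frames
```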
@@ -217,7 +225,8 @@ if init_from == 'scratch':
     else:
         model_config.vocab_size = meta_vocab_size
     if model_type == 'mamba':
-        model = MambaLM(model_config)
+        #model = MambaLM(model_config)
+        model = MambaLMHeadModel(model_config)
     else:
         model = GPT(model_config)
     if auto_clip:
@@ -233,7 +242,8 @@ elif init_from == 'resume' or init_from == 'anneal':
     checkpoint = torch.load(ckpt_path, map_location=device)
     model_config = checkpoint['model_args']
     if model_type == 'mamba':
-        model = MambaLM(model_config)
+        #model = MambaLM(model_config)
+        model = MambaLMHeadModel(model_config)
     else:
         model = GPT(model_config)
     state_dict = checkpoint['model']
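Both the scratch and resume/anneal paths now construct MambaLMHeadModel from the derived or saved config. The sketch below shows how the resume branch fits together under stated assumptions: the checkpoint path, device, and the '_orig_mod.' prefix strip are illustrative (that cleanup is the usual nanoGPT convention for weights saved from a torch.compile wrapper, not necessarily present in this script).

```python
# Sketch only: resuming a MambaLMHeadModel from a checkpoint with the layout used above.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

ckpt_path = 'out/ckpt.pt'   # illustrative location
device = 'cuda'

checkpoint = torch.load(ckpt_path, map_location=device)
model_config = checkpoint['model_args']      # the MambaConfig saved at training time
model = MambaLMHeadModel(model_config)

state_dict = checkpoint['model']
prefix = '_orig_mod.'                        # left behind by torch.compile wrappers
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(device)
```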
@@ -309,10 +319,11 @@ if ddp:
 
 def batch_to_loss(sequences, max_length_in_batch):
     if model_type == 'mamba':
-        logits = model(sequences[:, :-1]) # Forward pass, exclude last token for input
+        logits = model(sequences[:, :-1]).logits # Forward pass, exclude last token for input
         # Compute loss (assuming next token prediction task)
         targets = sequences[:, 1:].reshape(-1) # Shifted by one for next token prediction
         return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
+        #return F.cross_entropy(logits.reshape(-1), targets)
     else:
         inputs = sequences[:, :-1]
         targets = sequences[:, 1:].reshape(-1)
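The functional change in the loss path is the '.logits' accessor: mamba_ssm's MambaLMHeadModel returns a namedtuple-style CausalLMOutput whose logits field holds the (batch, seq_len, vocab) tensor, whereas the previous MambaLM returned the tensor directly. A self-contained sketch of the same next-token objective, with shapes noted as comments:

```python
# Sketch only: next-token cross-entropy against a model whose forward returns .logits.
import torch
import torch.nn.functional as F

def next_token_loss(model, sequences):
    # sequences: (batch, seq_len) integer token ids for one padded batch of games.
    inputs = sequences[:, :-1]                    # predict token t+1 from tokens <= t
    targets = sequences[:, 1:].reshape(-1)
    logits = model(inputs).logits                 # (batch, seq_len - 1, vocab)
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
```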
@@ -474,7 +485,7 @@ while True:
     scaler.update()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
-    torch.cuda.empty_cache()
+    #torch.cuda.empty_cache()
 
     # timing and logging
     t1 = time.time()
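Commenting out the per-iteration torch.cuda.empty_cache() avoids paying allocator overhead on every step; PyTorch's caching allocator reuses blocks freed by zero_grad(set_to_none=True) without returning them to the driver. If cache pressure ever needs relieving, doing it infrequently is a reasonable compromise; the helper below is only a sketch, and its name and interval are made up rather than knobs in this script.

```python
# Sketch only: clear the CUDA cache occasionally instead of every training iteration.
import torch

def maybe_empty_cache(step, every=1000):
    # Releasing cached blocks back to the driver each step adds overhead for little
    # benefit, since the caching allocator already reuses freed memory.
    if torch.cuda.is_available() and step % every == 0:
        torch.cuda.empty_cache()
```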
 