HaileyStorm commited on
Commit
d29de63
·
verified ·
1 Parent(s): 164b5fe

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -126,7 +126,7 @@ class MambaPlayer:
126
  tensor_output = output
127
  seq_len = tensor_output.shape[1]
128
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
129
- self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
130
  self.activations_count[layer_idx][bucket]["current"] += 1
131
 
132
  self.hooks.append(layer.register_forward_hook(hook))
@@ -324,8 +324,8 @@ class MambaPlayer:
324
  def get_lr(it):
325
  warmup_iters = 150 * 43
326
  lr_decay_iters = 5000 * 43
327
- learning_rate = 0.000015
328
- min_lr = 0.000001
329
  # 1) linear warmup for warmup_iters steps
330
  if it < warmup_iters:
331
  return learning_rate * it / warmup_iters
@@ -345,7 +345,7 @@ class MambaPlayer:
345
  for layer_idx in self.linear_probes:
346
  for bucket in self.move_buckets:
347
  if self.activations_count[layer_idx][bucket]['current'] > 0:
348
- X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)[:self.seq_len][-8:] #/ self.activations_count[layer_idx][bucket]['current']).float()
349
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
350
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
351
  if len(y) > 0:
 
126
  tensor_output = output
127
  seq_len = tensor_output.shape[1]
128
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
129
+ self.activations_sum[layer_idx][bucket]["current"][:, :8, :] += tensor_output.detach().cpu().numpy()[:self.seq_len][-8:]
130
  self.activations_count[layer_idx][bucket]["current"] += 1
131
 
132
  self.hooks.append(layer.register_forward_hook(hook))
 
324
  def get_lr(it):
325
  warmup_iters = 150 * 43
326
  lr_decay_iters = 5000 * 43
327
+ learning_rate = 0.003
328
+ min_lr = 0.0001
329
  # 1) linear warmup for warmup_iters steps
330
  if it < warmup_iters:
331
  return learning_rate * it / warmup_iters
 
345
  for layer_idx in self.linear_probes:
346
  for bucket in self.move_buckets:
347
  if self.activations_count[layer_idx][bucket]['current'] > 0:
348
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
349
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
350
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
351
  if len(y) > 0: