HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -126,7 +126,7 @@ class MambaPlayer:
             tensor_output = output
             seq_len = tensor_output.shape[1]
             bucket = next(b for b in self.move_buckets if self.move_num <= b)
-            self.activations_sum[layer_idx][bucket]["current"][:, :
+            self.activations_sum[layer_idx][bucket]["current"][:, :8, :] += tensor_output.detach().cpu().numpy()[:self.seq_len][-8:]
             self.activations_count[layer_idx][bucket]["current"] += 1
 
         self.hooks.append(layer.register_forward_hook(hook))
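This hunk changes the forward hook so the per-layer, per-bucket running sum keeps only the last 8 sequence positions of the layer output. Below is a minimal, self-contained sketch of that hook pattern, assuming a (batch, seq_len, hidden) output and hypothetical buffer names; it is not the repository's MambaPlayer code.

import numpy as np
import torch
import torch.nn as nn

WINDOW = 8    # trailing positions kept, mirroring the [:, :8, :] / [-8:] slices
HIDDEN = 16   # hypothetical hidden size

layer = nn.Linear(HIDDEN, HIDDEN)
activations_sum = np.zeros((1, WINDOW, HIDDEN), dtype=np.float32)
activations_count = 0

def hook(module, inputs, output):
    global activations_count
    acts = output.detach().cpu().numpy()   # shape (batch, seq_len, hidden)
    window = acts[:, -WINDOW:, :]          # last WINDOW positions of the sequence
    activations_sum[:, :window.shape[1], :] += window
    activations_count += 1

handle = layer.register_forward_hook(hook)
layer(torch.randn(1, 12, HIDDEN))          # batch=1, seq_len=12
handle.remove()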
@@ -324,8 +324,8 @@ class MambaPlayer:
         def get_lr(it):
             warmup_iters = 150 * 43
             lr_decay_iters = 5000 * 43
-            learning_rate = 0.
-            min_lr = 0.
+            learning_rate = 0.003
+            min_lr = 0.0001
             # 1) linear warmup for warmup_iters steps
             if it < warmup_iters:
                 return learning_rate * it / warmup_iters
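The second hunk sets concrete values for the probe schedule: a peak learning rate of 0.003 and a floor of 0.0001, with 150 * 43 warmup iterations and a 5000 * 43 decay horizon. Only the linear-warmup branch is visible in the diff; the sketch below assumes the remainder is a nanoGPT-style cosine decay to min_lr, which may differ from what mamba_module.py actually does.

import math

warmup_iters = 150 * 43
lr_decay_iters = 5000 * 43
learning_rate = 0.003
min_lr = 0.0001

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) past the decay horizon, hold at the floor
    if it > lr_decay_iters:
        return min_lr
    # 3) cosine decay from learning_rate down to min_lr (assumed)
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)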
@@ -345,7 +345,7 @@ class MambaPlayer:
         for layer_idx in self.linear_probes:
             for bucket in self.move_buckets:
                 if self.activations_count[layer_idx][bucket]['current'] > 0:
-                    X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
+                    X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
                     for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
                         y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
                         if len(y) > 0:
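The last hunk builds the probe input X from the summed activations, flattened past the batch dimension; the commented-out division by the count suggests a switch from averaged to summed activations. As a rough illustration of fitting one such linear probe, here is a hedged sketch with stand-in tensors; the probe head, optimizer, and epoch count are assumptions, not the repository's training loop.

import torch
import torch.nn as nn

X = torch.randn(64, 8 * 16)   # stand-in for the flattened activation features
y = torch.randn(64, 1)        # stand-in for one probe target, e.g. material_balance

probe = nn.Linear(X.shape[1], 1)
optimizer = torch.optim.AdamW(probe.parameters(), lr=0.003)
loss_fn = nn.MSELoss()

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(probe(X), y)
    loss.backward()
    optimizer.step()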