HaileyStorm committed on
Commit
f6ed371
1 Parent(s): bf10919

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -12,6 +12,7 @@ import torch.nn as nn
12
  import torch.optim as optim
13
  import wandb
14
  import math
 
15
 
16
  BASE_DIR = "mamba/"
17
 
@@ -376,13 +377,33 @@ class MambaPlayer:
376
  def evaluate_linear_probes(self, board: chess.Board):
377
  self.move_num = board.fullmove_number
378
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
 
 
 
 
379
  for layer_idx in self.linear_probes:
380
  X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
381
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
382
  target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
383
  probe = self.linear_probes[layer_idx][probe_type]
384
- #probe.eval()
385
  prediction = probe(X).item()
386
- if probe_type == 'material_balance':
387
- print(f"Layer {layer_idx}, {probe_type}: {int(prediction)} vs {int(target)}")
 
 
 
 
 
 
 
 
 
 
 
 
388
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
 
 
 
 
 
12
  import torch.optim as optim
13
  import wandb
14
  import math
15
+ import json
16
 
17
  BASE_DIR = "mamba/"
18
 
 
def evaluate_linear_probes(self, board: chess.Board, stats_path: str = 'probe_stats.json'):
    """Score every linear probe for the current move and append the result to a log.

    For each layer in ``self.linear_probes`` and each probe type
    (``q_value``, ``q_value_delta``, ``material_balance``) the probe is run
    on the flattened ``'current'`` activations of the bucket that the
    current move falls into, and its prediction is compared with the single
    stored target value. The score is ``1 - |prediction - target| / range``,
    clamped to ``[0, 1]`` for all probe types so that a prediction outside
    the assumed range can never yield a negative score.

    Side effects:
        * sets ``self.move_num`` to ``board.fullmove_number``;
        * resets ``self.linear_probe_targets`` to empty lists for every
          layer and bucket;
        * appends one JSON line to ``stats_path`` shaped as
          ``{probe_type: {layer_idx: {move_num: accuracy}}}`` (JSON
          serialisation turns the layer and move keys into strings).

    Args:
        board: Position being evaluated; only ``fullmove_number`` is read.
        stats_path: File the JSON line is appended to. Defaults to the
            previously hard-coded ``'probe_stats.json'``.

    Raises:
        StopIteration: If the move number is larger than every entry of
            ``self.move_buckets``.
    """
    # Width of the value range each error is normalised by. The q-value
    # figures follow the original comments (q in [-1, 1], delta in [-2, 2]);
    # 35 is the original's tunable guess for material balance.
    probe_ranges = {'q_value': 2.0, 'q_value_delta': 4.0, 'material_balance': 35.0}

    self.move_num = board.fullmove_number
    # First bucket whose upper bound covers this move.
    bucket = next(b for b in self.move_buckets if self.move_num <= b)

    probe_stats = {probe_type: {} for probe_type in probe_ranges}

    # Only scalar predictions are read, so no autograd graph is needed.
    with torch.no_grad():
        for layer_idx in self.linear_probes:
            X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
            for probe_type, value_range in probe_ranges.items():
                # .item() requires exactly one stored target for this probe.
                target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
                prediction = self.linear_probes[layer_idx][probe_type](X).item()

                relative_error = abs(prediction - target) / value_range
                accuracy = 1.0 - min(relative_error, 1.0)
                probe_stats[probe_type][layer_idx] = {self.move_num: accuracy}

    # Clear the targets so the next evaluation starts from empty lists.
    self.linear_probe_targets = {
        i: {
            b: {'q_value': [], 'q_value_delta': [], 'material_balance': []}
            for b in self.move_buckets
        }
        for i in self.linear_probes
    }

    # One JSON document per line (JSON Lines), appended move by move.
    with open(stats_path, 'a', encoding='utf-8') as f:
        json.dump(probe_stats, f)
        f.write('\n')