HaileyStorm committed on
Commit
432e67d
·
verified ·
1 Parent(s): e8aba5c

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -137,7 +137,7 @@ class MambaPlayer:
137
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
138
  self.linear_optimizers = {
139
  layer_idx: {
140
- probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
141
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']
142
  }
143
  for layer_idx in self.linear_probes
@@ -312,7 +312,7 @@ class MambaPlayer:
312
  self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
313
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
314
 
315
- def train_linear_probes(self, lr=0.01):
316
  criterion = nn.MSELoss()
317
 
318
  for layer_idx in self.linear_probes:
 
137
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
138
  self.linear_optimizers = {
139
  layer_idx: {
140
+ probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=0.01)
141
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']
142
  }
143
  for layer_idx in self.linear_probes
 
312
  self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
313
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
314
 
315
+ def train_linear_probes(self):
316
  criterion = nn.MSELoss()
317
 
318
  for layer_idx in self.linear_probes: