HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -137,7 +137,7 @@ class MambaPlayer:
|
|
137 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
138 |
self.linear_optimizers = {
|
139 |
layer_idx: {
|
140 |
-
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=
|
141 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
|
142 |
}
|
143 |
for layer_idx in self.linear_probes
|
@@ -312,7 +312,7 @@ class MambaPlayer:
|
|
312 |
self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
|
313 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
314 |
|
315 |
-
def train_linear_probes(self
|
316 |
criterion = nn.MSELoss()
|
317 |
|
318 |
for layer_idx in self.linear_probes:
|
|
|
137 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
138 |
self.linear_optimizers = {
|
139 |
layer_idx: {
|
140 |
+
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=0.01)
|
141 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
|
142 |
}
|
143 |
for layer_idx in self.linear_probes
|
|
|
312 |
self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
|
313 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
314 |
|
315 |
+
def train_linear_probes(self):
|
316 |
criterion = nn.MSELoss()
|
317 |
|
318 |
for layer_idx in self.linear_probes:
|