HaileyStorm
committed
Update chess-gpt-eval-contrastive/mamba_module.py
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
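Summary of the change (from the diff below): the body of the contrastive-steering forward hook in MambaPlayer is replaced. The new version computes the steering tensor for the current move bucket (mean "won" activations minus mean "lost" activations), zeroes any non-finite entries, and stores the result in a new self.contrastive_activations_cache dictionary keyed by layer and bucket, so subsequent forward passes reuse it instead of recomputing.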
@@ -227,17 +227,40 @@ class MambaPlayer:
             with open(path, "rb") as f:
                 activations_sum, activations_count = pickle.load(f)

-            [12 lines removed: the previous hook definition and registration `for` loop; content not recoverable from this view]
+            self.contrastive_activations_cache = {}
+
+            def hook(module, input, output, layer_idx):
+                if isinstance(output, tuple):
+                    tensor_output = output[0]
+                else:
+                    tensor_output = output
+                seq_len = tensor_output.shape[1]
+                bucket = next(b for b in self.move_buckets if self.move_num <= b)
+
+                # Check cache first
+                if layer_idx in self.contrastive_activations_cache and bucket in self.contrastive_activations_cache[layer_idx]:
+                    safe_contrastive_activations = self.contrastive_activations_cache[layer_idx][bucket]
+                else:
+                    won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
+                    lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
+                    contrastive_activations = won_activations - lost_activations
+                    contrastive_activations_tensor = torch.from_numpy(contrastive_activations).to(tensor_output.device)
+                    valid_activations = torch.isfinite(contrastive_activations_tensor)
+                    safe_contrastive_activations = torch.zeros_like(contrastive_activations_tensor)
+                    safe_contrastive_activations[valid_activations] = contrastive_activations_tensor[valid_activations]

+                    # Cache the safe activations
+                    if layer_idx not in self.contrastive_activations_cache:
+                        self.contrastive_activations_cache[layer_idx] = {}
+                    self.contrastive_activations_cache[layer_idx][bucket] = safe_contrastive_activations

+                tensor_output += safe_contrastive_activations[:, :seq_len, :] * weight
+                if isinstance(output, tuple):
+                    return tensor_output, output[1]
+                else:
+                    return tensor_output
+
+            for layer_idx in activations_sum:
+                self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
+                    lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
+                ))
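For readers unfamiliar with the technique, the new hook implements contrastive activation steering: the mean activation over won games minus the mean over lost games (per layer, per move-number bucket) gives a direction that is added to the layer's output at inference time, scaled by `weight`. Non-finite entries are zeroed because a bucket with a zero count makes the division produce inf/NaN. Below is a minimal, self-contained sketch of the same mechanics against a toy `nn.Linear` stack; all names here (`make_hook`, the toy statistics, `d_model`, etc.) are illustrative assumptions rather than code from this repo, and the per-move bucketing is omitted for brevity.

import numpy as np
import torch
import torch.nn as nn

# Toy stand-in for the Mamba backbone: any modules whose outputs we can hook.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

d_model, max_seq_len = 8, 16
rng = np.random.default_rng(0)

# Hypothetical precomputed statistics, shaped like the pickled
# activations_sum / activations_count above: per layer, summed activations
# and game counts for "won" and "lost" outcomes.
activations_sum = {
    0: {"won": rng.normal(size=(1, max_seq_len, d_model)),
        "lost": rng.normal(size=(1, max_seq_len, d_model))}
}
activations_count = {0: {"won": 10, "lost": 10}}

weight = 0.5  # steering strength, analogous to `weight` in the commit

def make_hook(layer_idx):
    # Contrastive direction: mean "won" activation minus mean "lost" activation.
    won = activations_sum[layer_idx]["won"] / activations_count[layer_idx]["won"]
    lost = activations_sum[layer_idx]["lost"] / activations_count[layer_idx]["lost"]
    steer = torch.from_numpy(won - lost).float()
    # Zero non-finite entries, mirroring the isfinite mask in the commit.
    steer = torch.where(torch.isfinite(steer), steer, torch.zeros_like(steer))

    def hook(module, inputs, output):
        seq_len = output.shape[1]
        # Add the steering tensor, truncated to the current sequence length.
        return output + steer[:, :seq_len, :].to(output.device) * weight

    return hook

handles = [model[i].register_forward_hook(make_hook(i)) for i in activations_sum]

x = torch.randn(1, max_seq_len, d_model)
steered = model(x)  # outputs now include the steering term at layer 0

for h in handles:
    h.remove()  # detach the hooks once steering is no longer wanted

Precomputing the steering tensor inside `make_hook`'s closure plays the same role as the new `contrastive_activations_cache`: the division and masking happen once per layer instead of on every forward pass. The commit's `lambda ... layer_idx=layer_idx` uses the default-argument trick for the same reason the factory function is used here: to bind the loop variable at registration time rather than at call time.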