HaileyStorm committed on
Commit
4b10b3e
·
verified ·
1 Parent(s): da8a8f2

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -227,17 +227,40 @@ class MambaPlayer:
227
  with open(path, "rb") as f:
228
  activations_sum, activations_count = pickle.load(f)
229
 
230
- def hook(module, input, output, layer_idx, bucket):
231
- seq_len = output.shape[1]
232
- won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
233
- lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
234
- contrastive_activations = won_activations - lost_activations
235
- return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device) * weight
236
-
237
- for layer_idx in activations_sum:
238
- for bucket in self.move_buckets:
239
- self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
240
- lambda module, input, output, layer_idx=layer_idx, bucket=bucket: hook(module, input, output, layer_idx, bucket)
241
- ))
 
 
 
 
 
 
 
 
 
242
 
 
 
 
 
243
 
 
 
 
 
 
 
 
 
 
 
 
# Load the accumulated per-layer / per-move-bucket activation statistics.
# NOTE(review): pickle.load on an untrusted file can execute arbitrary code;
# `path` is assumed to be a locally produced activations file — confirm.
with open(path, "rb") as f:
    activations_sum, activations_count = pickle.load(f)

# Per-(layer, bucket) cache of the sanitized contrastive-activation tensors,
# so the division / finite-masking work runs at most once per combination.
self.contrastive_activations_cache = {}

def hook(module, input, output, layer_idx):
    """Forward hook: steer the layer output toward "won" games.

    Adds `weight` times the contrastive activation delta (mean "won"
    activations minus mean "lost" activations) for this layer and the
    current move bucket onto the layer's hidden-state output.

    `output` may be a bare tensor or a tuple whose first element is the
    hidden-state tensor (assumes shape (batch, seq, dim) — TODO confirm).
    """
    tensor_output = output[0] if isinstance(output, tuple) else output
    seq_len = tensor_output.shape[1]
    # First bucket whose upper bound covers the current move number.
    # A bare next() would raise StopIteration mid-forward-pass when
    # move_num exceeds every bucket, so fall back to the last bucket.
    bucket = next(
        (b for b in self.move_buckets if self.move_num <= b),
        self.move_buckets[-1],
    )

    layer_cache = self.contrastive_activations_cache.setdefault(layer_idx, {})
    safe_contrastive_activations = layer_cache.get(bucket)
    if safe_contrastive_activations is None:
        # Mean activation per outcome class. A count of 0 yields inf/nan
        # from the division; those entries are zeroed by the finite mask
        # below so they contribute no steering.
        won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
        lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
        contrastive_activations = won_activations - lost_activations
        contrastive_tensor = torch.from_numpy(contrastive_activations).to(tensor_output.device)
        safe_contrastive_activations = torch.where(
            torch.isfinite(contrastive_tensor),
            contrastive_tensor,
            torch.zeros_like(contrastive_tensor),
        )
        layer_cache[bucket] = safe_contrastive_activations

    # Out-of-place add: the original `tensor_output += ...` mutated the
    # layer's output tensor in place, so every other consumer of that
    # tensor saw the steered values and the tuple/non-tuple return
    # distinction was moot.
    steered = tensor_output + safe_contrastive_activations[:, :seq_len, :] * weight
    if isinstance(output, tuple):
        # Preserve any extra tuple elements (original hard-coded a 2-tuple).
        return (steered,) + output[1:]
    return steered

# Register one hook per layer that has recorded statistics. layer_idx is
# bound as a default argument to avoid the late-binding-closure pitfall.
for layer_idx in activations_sum:
    self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
        lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
    ))