ELIA / circuit_analysis /attribution_graphs_olmo.py
aaron0eidt's picture
Deploy static demo
5b6c556
#!/usr/bin/env python3
# This script generates attribution graphs for the OLMo2 7B model.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Any, Set
import json
import logging
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import networkx as nx
from dataclasses import dataclass
from tqdm import tqdm
import pickle
import requests
import time
import random
import copy
import os
import argparse
# --- Add this block to fix the import path ---
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
# ---------------------------------------------
from utilities.utils import init_qwen_api, set_seed
# --- Constants ---
RESULTS_DIR = "circuit_analysis/results"
CLT_SAVE_PATH = "circuit_analysis/models/clt_model.pth"

# Configure logging once, at import time, for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _pick_device() -> torch.device:
    """Choose the compute device and log the choice.

    Preference order: Apple MPS, then CUDA, then CPU.
    """
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        logger.info("Using MPS (Metal Performance Shaders) for GPU acceleration")
        return chosen
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        logger.info("Using CUDA for GPU acceleration")
        return chosen
    chosen = torch.device("cpu")
    logger.info("Using CPU")
    return chosen


# Device used by the rest of the script (and the default for AttributionGraphConfig.device).
DEVICE = _pick_device()
@dataclass
class AttributionGraphConfig:
    # Configuration for building the attribution graph.
    # Shared settings for CLT training, graph construction and ablation runs.

    # Checkpoint directory passed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
    model_path: str = "./models/OLMo-2-1124-7B"
    max_seq_length: int = 512  # Truncation length for training texts in train_clt
    n_features_per_layer: int = 512 # Number of features in each CLT layer
    # Weight of the sparsity penalty in the CLT training loss.
    # NOTE(review): an earlier note called this "L1 sparsity"; confirm it
    # matches the penalty actually computed in train_clt.
    sparsity_lambda: float = 1e-3
    reconstruction_loss_weight: float = 1.0  # Weight of the reconstruction term in the CLT loss
    batch_size: int = 8  # Texts sampled per CLT training step
    learning_rate: float = 1e-4  # Adam learning rate for CLT training
    training_steps: int = 1000  # Number of CLT optimizer steps
    # Default is evaluated once, at class-definition time, from the module-level DEVICE.
    device: str = str(DEVICE)
    # Keep fraction used by AttributionGraph.prune_graph (embedding and output
    # nodes are always kept regardless).
    pruning_threshold: float = 0.8 # For graph pruning
    intervention_strength: float = 5.0 # For perturbation experiments
    # Qwen API settings for feature interpretation; only used when it contains an 'api_key'.
    qwen_api_config: Optional[Dict[str, str]] = None
    # The ablation / summary / random-baseline settings below are consumed by
    # code outside this view — confirm their exact semantics there.
    max_ablation_experiments: Optional[int] = None
    ablation_top_k_tokens: int = 5
    ablation_features_per_layer: Optional[int] = 2
    summary_max_layers: Optional[int] = None
    summary_features_per_layer: Optional[int] = 2
    random_baseline_trials: int = 5
    random_baseline_features: int = 1
    random_baseline_seed: int = 1234
    path_ablation_top_k: int = 3
    random_path_baseline_trials: int = 5
    # Graph construction (AttributionGraph.construct_graph):
    # number of top features per layer considered as nodes (20 if falsy).
    graph_max_features_per_layer: int = 40
    # A feature becomes a node only if its activation exceeds this value.
    graph_feature_activation_threshold: float = 0.01
    # |virtual weight| must exceed this for a feature->feature edge.
    graph_edge_weight_threshold: float = 0.0
    # Cap on feature->feature edges per source node (5 if falsy).
    graph_max_edges_per_node: int = 12
class JumpReLU(nn.Module):
    """JumpReLU activation: keep values above ``threshold`` unchanged, zero the rest.

    Unlike a shifted ReLU (``relu(x - threshold)``), surviving activations keep
    their original magnitude; the threshold only gates them. With the default
    ``threshold=0.0`` this is exactly ``F.relu``.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        # Gate value; a plain float, not a trainable parameter.
        self.threshold = threshold

    def forward(self, x):
        # torch.where (instead of x * mask) yields +0.0 for gated entries, so the
        # output and gradient match F.relu bit-for-bit when threshold == 0.
        return torch.where(x > self.threshold, x, torch.zeros_like(x))
class CrossLayerTranscoder(nn.Module):
    """Cross-Layer Transcoder (CLT).

    One encoder per layer maps that layer's residual activations
    (``hidden_size``) to ``n_features`` features. A decoder exists for every
    layer pair ``source_layer <= target_layer`` (keyed ``"{s}_to_{t}"``), and
    the reconstruction for a target layer is the sum of the decoded features of
    every layer at or below it.
    """

    def __init__(self, model_config: Dict, clt_config: AttributionGraphConfig):
        super().__init__()
        self.config = clt_config
        self.model_config = model_config
        self.n_layers = model_config['num_hidden_layers']
        self.hidden_size = model_config['hidden_size']
        self.n_features = clt_config.n_features_per_layer
        # Encoder weights for each layer.
        self.encoders = nn.ModuleList([
            nn.Linear(self.hidden_size, self.n_features, bias=False)
            for _ in range(self.n_layers)
        ])
        # Decoder weights for cross-layer connections: source layer -> every later-or-equal layer.
        self.decoders = nn.ModuleDict()
        for source_layer in range(self.n_layers):
            for target_layer in range(source_layer, self.n_layers):
                key = f"{source_layer}_to_{target_layer}"
                self.decoders[key] = nn.Linear(self.n_features, self.hidden_size, bias=False)
        # The activation function.
        self.activation = JumpReLU(threshold=0.0)
        # Initialize the weights.
        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear weight from N(0, 0.01^2) (layers have no bias)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def encode(self, layer_idx: int, residual_activations: torch.Tensor) -> torch.Tensor:
        """Encode one layer's residual activations into feature activations."""
        encoder = self.encoders[layer_idx]
        # The host model may run in float16 while CLT weights are float32;
        # nn.Linear requires matching dtypes, so align the input to the weights.
        residual = residual_activations.to(encoder.weight.dtype)
        return self.activation(encoder(residual))

    def decode(self, source_layer: int, target_layer: int, feature_activations: torch.Tensor) -> torch.Tensor:
        """Decode source-layer features into the target layer's output space."""
        decoder = self.decoders[f"{source_layer}_to_{target_layer}"]
        return decoder(feature_activations.to(decoder.weight.dtype))

    def forward(self, residual_activations: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Encode every layer, then reconstruct each layer from all layers at or below it.

        Args:
            residual_activations: One tensor per layer (exactly ``n_layers``).

        Returns:
            ``(feature_activations, reconstructed_outputs)``; each reconstruction
            has the same shape and dtype as the corresponding input tensor.

        Raises:
            ValueError: If the number of tensors differs from ``n_layers``.
        """
        if len(residual_activations) != self.n_layers:
            raise ValueError(
                f"Expected {self.n_layers} layer activations, got {len(residual_activations)}."
            )
        feature_activations = [
            self.encode(layer_idx, residual)
            for layer_idx, residual in enumerate(residual_activations)
        ]
        reconstructed_mlp_outputs = []
        for target_layer in range(self.n_layers):
            # Keep the caller's dtype for the reconstruction; in-place addition
            # casts the decoded contributions into it.
            reconstruction = torch.zeros_like(residual_activations[target_layer])
            for source_layer in range(target_layer + 1):
                reconstruction += self.decode(source_layer, target_layer, feature_activations[source_layer])
            reconstructed_mlp_outputs.append(reconstruction)
        return feature_activations, reconstructed_mlp_outputs
class FeatureVisualizer:
    """Summarise individual CLT features and attach short textual interpretations.

    Interpretations are memoised in ``feature_interpretations`` under the key
    ``"L{layer_idx}_F{feature_idx}"``. When ``cache_dir`` is given they are also
    persisted to ``<cache_dir>/feature_interpretations.json`` and reloaded on
    construction.
    """

    # A token counts as punctuation when every one of its characters is in this set.
    _PUNCTUATION_CHARS = frozenset('.,!?;:')

    def __init__(self, tokenizer, cache_dir: Optional[Path] = None):
        self.tokenizer = tokenizer
        self.feature_interpretations: Dict[str, str] = {}
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            self.cache_dir = Path(self.cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_cache()

    def _cache_file(self) -> Optional[Path]:
        """Return the JSON cache path, or None when persistence is disabled."""
        if self.cache_dir is None:
            return None
        return self.cache_dir / "feature_interpretations.json"

    def _load_cache(self):
        """Merge previously saved interpretations into memory (best effort)."""
        cache_file = self._cache_file()
        if cache_file is None or not cache_file.exists():
            return
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # Unreadable or corrupt cache (ValueError covers JSON/Unicode decode
            # errors): continue with an empty cache instead of aborting.
            logger.warning(f"Failed to load feature interpretation cache: {e}")
            return
        if isinstance(data, dict):
            self.feature_interpretations.update({str(k): str(v) for k, v in data.items()})

    def _save_cache(self):
        """Write all interpretations to the cache file (best effort)."""
        cache_file = self._cache_file()
        if cache_file is None:
            return
        try:
            with open(cache_file, 'w', encoding='utf-8') as f:
                json.dump(self.feature_interpretations, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to save feature interpretation cache: {e}")

    def visualize_feature(self, feature_idx: int, layer_idx: int,
                          activations: torch.Tensor, input_tokens: List[str],
                          top_k: int = 10) -> Dict:
        """Summarise one feature's activations over a prompt.

        Args:
            feature_idx: Column of ``activations`` to summarise.
            layer_idx: Layer index (recorded in the result only).
            activations: 2-D tensor indexed ``[position, feature]``.
            input_tokens: Token strings; positions beyond this list are skipped.
            top_k: Number of highest-activating positions to report (<= 0 -> none).

        Returns:
            Dict with max/mean activation, ``'sparsity'`` (fraction of positions
            with activation > 0.1) and ``'top_activations'`` in descending order.
            Statistics are 0.0 when there are no positions.
        """
        feature_acts = activations[:, feature_idx].detach().cpu().numpy()
        has_positions = feature_acts.size > 0
        visualization = {
            'feature_idx': feature_idx,
            'layer_idx': layer_idx,
            'max_activation': float(feature_acts.max()) if has_positions else 0.0,
            'mean_activation': float(feature_acts.mean()) if has_positions else 0.0,
            'sparsity': float((feature_acts > 0.1).mean()) if has_positions else 0.0,
            'top_activations': []
        }
        # Guard top_k <= 0 explicitly: slicing with [-0:] would return every position.
        if not has_positions or top_k <= 0:
            return visualization
        # Descending order of activation, truncated to top_k.
        top_positions = np.argsort(feature_acts)[::-1][:top_k]
        for pos in top_positions:
            if pos < len(input_tokens):
                visualization['top_activations'].append({
                    'token': input_tokens[pos],
                    'position': int(pos),
                    'activation': float(feature_acts[pos])
                })
        return visualization

    def interpret_feature(self, feature_idx: int, layer_idx: int,
                          visualization_data: Dict,
                          qwen_api_config: Optional[Dict[str, str]] = None) -> str:
        """Return a short label for a feature from its top-activating tokens.

        Uses the Qwen API when ``qwen_api_config`` has an ``'api_key'``;
        otherwise a token-class heuristic. Results are memoised per
        ``(layer_idx, feature_idx)`` and saved to the cache file if enabled.
        """
        top_tokens = [item['token'] for item in visualization_data['top_activations']]
        cache_key = f"L{layer_idx}_F{feature_idx}"
        if cache_key in self.feature_interpretations:
            return self.feature_interpretations[cache_key]
        if qwen_api_config and qwen_api_config.get('api_key'):
            interpretation = get_feature_interpretation_with_qwen(
                qwen_api_config, top_tokens, cache_key, layer_idx
            )
        else:
            interpretation = self._heuristic_interpretation(top_tokens)
        self.feature_interpretations[cache_key] = interpretation
        self._save_cache()
        return interpretation

    @classmethod
    def _heuristic_interpretation(cls, top_tokens: List[str]) -> str:
        """Label a feature by the character class shared by all of its top tokens."""
        if top_tokens and len(set(top_tokens)) == 1:
            return f"Specific token: '{top_tokens[0]}'"
        if top_tokens and all(token.isalpha() for token in top_tokens):
            return "Word/alphabetic tokens"
        if top_tokens and all(token.isdigit() for token in top_tokens):
            return "Numeric tokens"
        # Character-set test (not a substring test): "..." and "!," are
        # punctuation, while the empty string is not.
        if top_tokens and all(token and set(token) <= cls._PUNCTUATION_CHARS for token in top_tokens):
            return "Punctuation"
        return "Mixed/polysemantic feature"
class AttributionGraph:
    """Construct, prune and draw a feature-level attribution graph for one prompt.

    ``node_types`` maps node ids to ``"embedding"`` (one per input token),
    ``"feature"`` (a CLT feature active at a layer and token position) or
    ``"output"`` (the prediction target). Every edge weight is mirrored in
    ``edge_weights`` keyed by ``(source, target)``. ``feature_metadata``
    describes the feature nodes of the most recently constructed graph.
    """

    def __init__(self, clt: CrossLayerTranscoder, tokenizer, config: AttributionGraphConfig):
        self.clt = clt
        self.tokenizer = tokenizer
        self.config = config
        self.graph = nx.DiGraph()
        self.node_types = {}  # node id -> node kind ("embedding", "feature", "output")
        self.edge_weights = {}  # (source, target) -> edge weight
        self.feature_metadata: Dict[str, Dict[str, Any]] = {}

    def compute_virtual_weights(self, source_layer: int, target_layer: int,
                                source_feature: int, target_feature: int) -> float:
        """Return the virtual weight from a source feature to a later-layer target feature.

        Sums, over intermediate layers ``i`` in ``[source_layer, target_layer)``,
        the dot product of the source feature's ``source_layer -> i`` decoder
        column with the target feature's encoder row. Returns 0.0 unless
        ``target_layer > source_layer``.
        """
        if target_layer <= source_layer:
            return 0.0
        total_weight = 0.0
        # Pure read-out of parameters: no autograd graph is needed.
        with torch.no_grad():
            encoder_weight = self.clt.encoders[target_layer].weight[target_feature]  # [hidden_size]
            for intermediate_layer in range(source_layer, target_layer):
                decoder_key = f"{source_layer}_to_{intermediate_layer}"
                if decoder_key not in self.clt.decoders:
                    continue
                decoder_weight = self.clt.decoders[decoder_key].weight[:, source_feature]  # [hidden_size]
                total_weight += torch.dot(decoder_weight, encoder_weight).item()
        return total_weight

    @staticmethod
    def _is_forward_edge(source_info: Dict[str, Any], target_info: Dict[str, Any]) -> bool:
        """Return True if information can flow from source to target in a causal decoder.

        The target must be in a later layer at the same or a later token
        position, or in the same layer at a strictly later position. Edges to
        earlier token positions are rejected: causal attention never lets a
        later position influence an earlier one.
        """
        if target_info['token_pos'] < source_info['token_pos']:
            return False
        if target_info['layer'] > source_info['layer']:
            return True
        return (target_info['layer'] == source_info['layer']
                and target_info['token_pos'] > source_info['token_pos'])

    def _add_feature_nodes(self, input_tokens: List[str],
                           feature_activations: List[torch.Tensor]) -> Dict[str, Dict[str, Any]]:
        """Add a node for each strongly active (layer, position, feature) and return their info."""
        seq_len = len(input_tokens)
        # Limit features per layer to prevent the graph from exploding.
        max_features_per_layer = self.config.graph_max_features_per_layer or 20
        activation_threshold = self.config.graph_feature_activation_threshold
        active_features: Dict[str, Dict[str, Any]] = {}
        for layer_idx, features in enumerate(feature_activations):
            # features is indexed [batch, position, feature]; only batch item 0 is used.
            _, seq_len_layer, n_features = features.shape
            # Rank features by their mean activation across the sequence.
            layer_activations = features[0].mean(dim=0)
            top_features = torch.topk(layer_activations,
                                      k=min(max_features_per_layer, n_features)).indices.tolist()
            for token_pos in range(min(seq_len, seq_len_layer)):
                for feat_idx in top_features:
                    activation = features[0, token_pos, feat_idx].item()
                    if activation > activation_threshold:
                        node_id = f"feat_L{layer_idx}_T{token_pos}_F{feat_idx}"
                        self.graph.add_node(node_id)
                        self.node_types[node_id] = "feature"
                        active_features[node_id] = {
                            'layer': layer_idx,
                            'token_pos': token_pos,
                            'feature_idx': feat_idx,
                            'activation': activation
                        }
                        self.feature_metadata[node_id] = {
                            'layer': layer_idx,
                            'token_position': token_pos,
                            'feature_index': feat_idx,
                            'activation': activation,
                            'input_token': input_tokens[token_pos]
                        }
        return active_features

    def construct_graph(self, input_tokens: List[str],
                        feature_activations: List[torch.Tensor],
                        target_token_idx: int = -1) -> nx.DiGraph:
        """Build the attribution graph for one prompt and return it.

        Adds embedding nodes, active feature nodes, and one output node. It then
        adds edges:

        * feature -> causally-later feature, weighted by source activation times
          virtual weight (at most ``graph_max_edges_per_node`` per source);
        * feature -> output, weighted by activation times a layer-dependent scale;
        * embedding -> layer-0 feature at the same position (activation * 0.5).

        All per-graph state, including ``feature_metadata``, is reset first.
        """
        self.graph.clear()
        self.node_types.clear()
        self.edge_weights.clear()
        # Metadata must describe only this prompt's nodes; drop stale entries.
        self.feature_metadata.clear()
        n_layers = len(feature_activations)

        # Add embedding nodes for the input tokens.
        for i, token in enumerate(input_tokens):
            node_id = f"emb_{i}_{token}"
            self.graph.add_node(node_id)
            self.node_types[node_id] = "embedding"

        active_features = self._add_feature_nodes(input_tokens, feature_activations)

        # Add an output node for the target token.
        output_node = f"output_{target_token_idx}"
        self.graph.add_node(output_node)
        self.node_types[output_node] = "output"

        feature_nodes = [node for node, type_ in self.node_types.items() if type_ == "feature"]
        print(f" Building attribution graph: {len(feature_nodes)} feature nodes, {len(self.graph.nodes())} total nodes")
        edge_weight_threshold = self.config.graph_edge_weight_threshold
        max_edges_per_node = max(self.config.graph_max_edges_per_node or 5, 1)
        for i, source_node in enumerate(feature_nodes):
            if i % 50 == 0:  # Progress indicator
                print(f" Processing node {i+1}/{len(feature_nodes)}")
            source_info = active_features[source_node]
            source_activation = source_info['activation']
            edges_added = 0
            for target_node in feature_nodes:
                if edges_added >= max_edges_per_node:
                    break
                if target_node == source_node:
                    continue
                target_info = active_features[target_node]
                if not self._is_forward_edge(source_info, target_info):
                    continue
                virtual_weight = self.compute_virtual_weights(
                    source_info['layer'], target_info['layer'],
                    source_info['feature_idx'], target_info['feature_idx']
                )
                if abs(virtual_weight) > edge_weight_threshold:
                    edge_weight = source_activation * virtual_weight
                    self.graph.add_edge(source_node, target_node, weight=edge_weight)
                    self.edge_weights[(source_node, target_node)] = edge_weight
                    edges_added += 1
            # Every layer may contribute to the output; the last two layers get
            # the full 0.1 scale, earlier layers a smaller one (floored at 0.05).
            layer_position = source_info['layer']
            layer_scale = 0.1 if layer_position >= n_layers - 2 else max(0.05, 0.1 * (layer_position + 1) / n_layers)
            output_weight = source_activation * layer_scale
            if abs(output_weight) > 0:
                self.graph.add_edge(source_node, output_node, weight=output_weight)
                self.edge_weights[(source_node, output_node)] = output_weight

        # Index layer-0 feature nodes by token position for the embedding edges.
        layer0_by_pos: Dict[int, List[str]] = defaultdict(list)
        for feat_node in feature_nodes:
            feat_info = active_features[feat_node]
            if feat_info['layer'] == 0:
                layer0_by_pos[feat_info['token_pos']].append(feat_node)
        for token_idx, token in enumerate(input_tokens):
            emb_node = f"emb_{token_idx}_{token}"
            for feat_node in layer0_by_pos.get(token_idx, ()):
                # Simplified direct connection from an embedding to a first-layer feature.
                weight = active_features[feat_node]['activation'] * 0.5
                self.graph.add_edge(emb_node, feat_node, weight=weight)
                self.edge_weights[(emb_node, feat_node)] = weight
        return self.graph

    def prune_graph(self, threshold: float = 0.8) -> nx.DiGraph:
        """Return a pruned copy of the graph.

        A node's importance is the sum of |weight| over its incident edges.
        The top ``int(n * threshold)`` of the ``n`` edge-bearing nodes are kept
        (``threshold`` is a keep fraction), plus every embedding and output node.
        """
        node_importance = defaultdict(float)
        for (source, target), weight in self.edge_weights.items():
            node_importance[source] += abs(weight)
            node_importance[target] += abs(weight)
        ranked = sorted(node_importance.items(), key=lambda item: item[1], reverse=True)
        n_keep = int(len(ranked) * threshold)
        important_nodes = {node for node, _ in ranked[:n_keep]}
        important_nodes.update(
            node for node, type_ in self.node_types.items() if type_ in ("output", "embedding")
        )
        return self.graph.subgraph(important_nodes).copy()

    def visualize_graph(self, graph: Optional[nx.DiGraph] = None, save_path: Optional[str] = None):
        """Draw ``graph`` (default: the current graph) and optionally save it to ``save_path``.

        Nodes are coloured by type; edge widths scale with |weight| up to 3.
        """
        if graph is None:
            graph = self.graph
        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(graph, k=1, iterations=50)
        colour_by_type = {"embedding": 'lightblue', "feature": 'lightgreen', "output": 'orange'}
        node_colors = [colour_by_type.get(self.node_types.get(node, "unknown"), 'gray')
                       for node in graph.nodes()]
        nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
                               node_size=300, alpha=0.8)
        edges = graph.edges()
        edge_weights = [abs(self.edge_weights.get((u, v), 0.1)) for u, v in edges]
        # Fall back to 1.0 for an empty graph or all-zero weights (avoids division by zero).
        max_weight = max(edge_weights, default=0.0) or 1.0
        edge_widths = [w / max_weight * 3 for w in edge_weights]
        nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6,
                               edge_color='gray', arrows=True)
        nx.draw_networkx_labels(graph, pos, font_size=8)
        plt.title("Attribution Graph")
        plt.axis('off')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
class PerturbationExperiments:
    """Validate attribution hypotheses by ablating CLT features inside the running model.

    An ablation subtracts ``intervention_strength`` times each selected
    feature's decoded contribution from the output of every transformer block
    the feature decodes into. It then compares the next-token distribution at
    the last position against the unperturbed baseline.
    """

    def __init__(self, model, clt: CrossLayerTranscoder, tokenizer):
        self.model = model
        self.clt = clt
        self.tokenizer = tokenizer
        # Resolved lazily by _get_transformer_blocks and cached afterwards.
        self._transformer_blocks: Optional[List[nn.Module]] = None

    def _get_transformer_blocks(self) -> List[nn.Module]:
        """Return the model's transformer blocks in layer order (cached after the first call).

        Any ``nn.ModuleList`` whose length equals ``config.num_hidden_layers``
        is a candidate. Names ending in ``layers``, then ``blocks``, then ``h``
        are preferred, with ties broken alphabetically.

        Raises:
            ValueError: If the layer count is unknown or no candidate list exists.
        """
        if self._transformer_blocks is not None:
            return self._transformer_blocks
        n_layers = getattr(self.model.config, "num_hidden_layers", None)
        if n_layers is None:
            raise ValueError("Model config does not expose num_hidden_layers; cannot resolve transformer blocks.")
        candidate_lists: List[Tuple[str, nn.ModuleList]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if isinstance(module, nn.ModuleList) and len(module) == n_layers
        ]
        if not candidate_lists:
            raise ValueError("Unable to locate transformer block ModuleList in model.")
        preferred_suffixes = ("layers", "blocks", "h")

        def _score(name: str) -> Tuple[int, str]:
            for idx, suffix in enumerate(preferred_suffixes):
                if name.endswith(suffix):
                    return (idx, name)
            return (len(preferred_suffixes), name)

        selected_name, selected_list = min(candidate_lists, key=lambda item: _score(item[0]))
        self._transformer_blocks = list(selected_list)
        logger.debug(f"Resolved transformer blocks from ModuleList '{selected_name}'.")
        return self._transformer_blocks

    def _format_top_tokens(self, top_tokens: torch.return_types.topk) -> List[Tuple[str, float]]:
        """Convert a topk result into ``[(decoded_token, probability), ...]``."""
        return [
            (self.tokenizer.decode([idx]), prob)
            for idx, prob in zip(top_tokens.indices.tolist(), top_tokens.values.tolist())
        ]

    def _prepare_inputs(self, input_text: str, top_k: int) -> Dict[str, Any]:
        """Tokenise ``input_text`` (batch size 1) and run the unperturbed model and the CLT.

        The returned dict holds the tokenised ``inputs``, baseline logits and
        float32 probabilities at the last position, the baseline top-``top_k``
        tokens, ``hidden_states`` (the model's hidden states with index 0
        dropped), the CLT ``feature_activations``, and
        ``default_target_token_id`` (the baseline's most likely next token).

        Raises:
            ValueError: If tokenisation yields more than one sequence.
        """
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        device = next(self.model.parameters()).device
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        if inputs["input_ids"].size(0) != 1:
            raise ValueError("Perturbation experiments currently support only batch size 1.")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            baseline_outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        baseline_logits = baseline_outputs.logits[0]
        target_position = baseline_logits.size(0) - 1
        baseline_last_token_logits = baseline_logits[target_position]
        # Softmax in float32: with a large vocabulary most float16 probabilities
        # underflow to exactly zero.
        baseline_probs = F.softmax(baseline_last_token_logits.float(), dim=-1)
        baseline_top_tokens = torch.topk(baseline_probs, k=top_k)
        hidden_states: List[torch.Tensor] = list(baseline_outputs.hidden_states[1:])
        # The model may run in float16 while the CLT is float32; feed the CLT its own dtype.
        clt_dtype = next(self.clt.parameters()).dtype
        with torch.no_grad():
            feature_activations, _ = self.clt([h.to(clt_dtype) for h in hidden_states])
        return {
            'inputs': inputs,
            'baseline_outputs': baseline_outputs,
            'baseline_logits': baseline_logits,
            'baseline_last_token_logits': baseline_last_token_logits,
            'baseline_probs': baseline_probs,
            'baseline_top_tokens': baseline_top_tokens,
            'target_position': target_position,
            'hidden_states': hidden_states,
            'feature_activations': feature_activations,
            'default_target_token_id': baseline_top_tokens.indices[0].item()
        }

    def _compute_feature_contributions(
        self,
        feature_activations: List[torch.Tensor],
        feature_set: List[Tuple[int, int]]
    ) -> Dict[int, torch.Tensor]:
        """Return ``{dest_layer: summed decoded contribution}`` for the given features.

        Each contribution is the outer product of a feature's activations
        (indexed ``[batch, position]``) with its ``layer -> dest_layer`` decoder
        column. Out-of-range layers or features are skipped.
        """
        contributions: Dict[int, torch.Tensor] = {}
        with torch.no_grad():
            for layer_idx, feature_idx in feature_set:
                if layer_idx >= len(feature_activations):
                    continue
                features = feature_activations[layer_idx]
                if feature_idx >= features.size(-1):
                    continue
                feature_values = features[:, :, feature_idx].detach()
                for dest_layer in range(layer_idx, self.clt.n_layers):
                    decoder_key = f"{layer_idx}_to_{dest_layer}"
                    if decoder_key not in self.clt.decoders:
                        continue
                    weight_column = self.clt.decoders[decoder_key].weight[:, feature_idx]
                    contrib = torch.einsum('bs,h->bsh', feature_values, weight_column).detach()
                    if dest_layer in contributions:
                        contributions[dest_layer] += contrib
                    else:
                        contributions[dest_layer] = contrib
        return contributions

    def _run_with_hooks(
        self,
        inputs: Dict[str, torch.Tensor],
        contributions: Dict[int, torch.Tensor],
        intervention_strength: float
    ):
        """Run the model once with ``strength * contribution`` subtracted from each hooked block's output.

        Hooks are always removed afterwards, even if the forward pass raises.
        """
        blocks = self._get_transformer_blocks()
        handles: List[Any] = []

        def _make_hook(cached_contrib: torch.Tensor):
            def hook(module, module_input, module_output):
                # Blocks may return a tensor, a tuple/list whose first item is
                # the hidden state, or an object with .last_hidden_state.
                if isinstance(module_output, torch.Tensor):
                    target_tensor = module_output
                elif isinstance(module_output, (tuple, list)):
                    target_tensor = module_output[0]
                elif hasattr(module_output, "last_hidden_state"):
                    target_tensor = module_output.last_hidden_state
                else:
                    raise TypeError(
                        f"Unsupported module output type '{type(module_output)}' for perturbation hook."
                    )
                tensor_contrib = cached_contrib.to(target_tensor.device).to(target_tensor.dtype)
                scaled = intervention_strength * tensor_contrib
                if isinstance(module_output, torch.Tensor):
                    return module_output - scaled
                elif isinstance(module_output, tuple):
                    modified = module_output[0] - scaled
                    return (modified,) + tuple(module_output[1:])
                elif isinstance(module_output, list):
                    modified = [module_output[0] - scaled, *module_output[1:]]
                    return modified
                else:
                    module_output.last_hidden_state = module_output.last_hidden_state - scaled
                    return module_output
            return hook

        try:
            for dest_layer, contrib in contributions.items():
                if dest_layer >= len(blocks):
                    continue
                handles.append(blocks[dest_layer].register_forward_hook(_make_hook(contrib)))
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        finally:
            for handle in handles:
                handle.remove()
        return outputs

    @staticmethod
    def _feature_set_records(feature_set) -> List[Dict[str, Any]]:
        """Describe a feature set as ``[{'layer': ..., 'feature': ...}, ...]``."""
        return [
            {'layer': layer_idx, 'feature': feature_idx}
            for layer_idx, feature_idx in feature_set
        ]

    @staticmethod
    def _distribution_shift(baseline_logits: torch.Tensor,
                            ablated_logits: torch.Tensor) -> Tuple[float, float]:
        """Return ``(KL(baseline || ablated), entropy(ablated) - entropy(baseline))`` in nats.

        Both are computed from float32 log-softmax. The older
        ``log(p + 1e-9)`` form on float16 probabilities produced NaN: 1e-9
        rounds to 0 in float16, so underflowed probabilities gave ``0 * log 0``,
        and that NaN was then reported as 0.0. Non-finite results still map to 0.0.
        """
        baseline_logp = F.log_softmax(baseline_logits.float(), dim=-1)
        ablated_logp = F.log_softmax(ablated_logits.float(), dim=-1)
        baseline_p = baseline_logp.exp()
        ablated_p = ablated_logp.exp()
        kl_divergence = torch.sum(baseline_p * (baseline_logp - ablated_logp)).item()
        entropy_baseline = -(baseline_p * baseline_logp).sum().item()
        entropy_ablated = -(ablated_p * ablated_logp).sum().item()
        entropy_change = entropy_ablated - entropy_baseline
        if not np.isfinite(kl_divergence):
            kl_divergence = 0.0
        if not np.isfinite(entropy_change):
            entropy_change = 0.0
        return kl_divergence, entropy_change

    def feature_set_ablation_experiment(
        self,
        input_text: str,
        feature_set: List[Tuple[int, int]],
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
        ablation_label: str = "feature_set",
        extra_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Ablate ``feature_set`` jointly and report the effect on the next-token prediction.

        Args:
            input_text: Prompt to analyse (single sequence).
            feature_set: ``(layer, feature)`` pairs to ablate together.
            intervention_strength: Multiplier on the subtracted contributions.
            target_token_id: Token whose probability/logit is tracked; defaults
                to the baseline's most likely next token.
            top_k: Number of top tokens reported before and after ablation.
            ablation_label: Stored under ``'ablation_type'``.
            extra_metadata: Merged into the result dict last.

        Returns:
            A result dict. It never raises: on failure it returns zeroed metrics
            with ``'ablation_applied': False`` and an ``'error'`` message.
        """
        try:
            baseline_data = self._prepare_inputs(input_text, top_k)
            if target_token_id is None:
                target_token_id = baseline_data['default_target_token_id']
            feature_set_normalized = [
                (int(layer_idx), int(feature_idx)) for layer_idx, feature_idx in feature_set
            ]
            contributions = self._compute_feature_contributions(
                baseline_data['feature_activations'],
                feature_set_normalized
            )
            baseline_probs = baseline_data['baseline_probs']
            baseline_top_tokens = baseline_data['baseline_top_tokens']
            baseline_last_token_logits = baseline_data['baseline_last_token_logits']
            target_position = baseline_data['target_position']
            hidden_states = baseline_data['hidden_states']
            baseline_prob = baseline_probs[target_token_id].item()
            baseline_logit = baseline_last_token_logits[target_token_id].item()
            baseline_summary = {
                'baseline_top_tokens': self._format_top_tokens(baseline_top_tokens),
                'baseline_probability': baseline_prob,
                'baseline_logit': baseline_logit
            }
            if not contributions:
                # Nothing to subtract: report an explicit no-op instead of running the model again.
                result = {
                    **baseline_summary,
                    'ablated_top_tokens': baseline_summary['baseline_top_tokens'],
                    'ablated_probability': baseline_prob,
                    'ablated_logit': baseline_logit,
                    'probability_change': 0.0,
                    'logit_change': 0.0,
                    'kl_divergence': 0.0,
                    'entropy_change': 0.0,
                    'hidden_state_delta_norm': 0.0,
                    'hidden_state_relative_change': 0.0,
                    'ablation_flips_top_prediction': False,
                    'feature_set': self._feature_set_records(feature_set_normalized),
                    'feature_set_size': len(feature_set_normalized),
                    'intervention_strength': intervention_strength,
                    'target_token_id': target_token_id,
                    'target_token': self.tokenizer.decode([target_token_id]),
                    'contributing_layers': [],
                    'ablation_applied': False,
                    'ablation_type': ablation_label,
                    'warning': 'no_contributions_found'
                }
                if extra_metadata:
                    result.update(extra_metadata)
                return result
            ablated_outputs = self._run_with_hooks(
                baseline_data['inputs'],
                contributions,
                intervention_strength
            )
            ablated_logits = ablated_outputs.logits[0, target_position]
            ablated_probs = F.softmax(ablated_logits.float(), dim=-1)
            ablated_top_tokens = torch.topk(ablated_probs, k=top_k)
            ablated_prob = ablated_probs[target_token_id].item()
            ablated_logit = ablated_logits[target_token_id].item()
            kl_divergence, entropy_change = self._distribution_shift(
                baseline_last_token_logits, ablated_logits
            )
            # Compare final hidden states at the target position (float32 for stable norms).
            baseline_hidden = hidden_states[-1][:, target_position, :].float()
            ablated_hidden = ablated_outputs.hidden_states[-1][:, target_position, :].float()
            hidden_delta_norm = torch.norm(baseline_hidden - ablated_hidden, dim=-1).item()
            hidden_baseline_norm = torch.norm(baseline_hidden, dim=-1).item()
            hidden_relative_change = hidden_delta_norm / (hidden_baseline_norm + 1e-9)
            result = {
                **baseline_summary,
                'ablated_top_tokens': self._format_top_tokens(ablated_top_tokens),
                'ablated_probability': ablated_prob,
                'ablated_logit': ablated_logit,
                'probability_change': baseline_prob - ablated_prob,
                'logit_change': baseline_logit - ablated_logit,
                'kl_divergence': kl_divergence,
                'entropy_change': entropy_change,
                'hidden_state_delta_norm': hidden_delta_norm,
                'hidden_state_relative_change': hidden_relative_change,
                'ablation_flips_top_prediction': bool(
                    baseline_top_tokens.indices[0].item() != ablated_top_tokens.indices[0].item()
                ),
                'feature_set': self._feature_set_records(feature_set_normalized),
                'feature_set_size': len(feature_set_normalized),
                'intervention_strength': intervention_strength,
                'target_token_id': target_token_id,
                'target_token': self.tokenizer.decode([target_token_id]),
                'contributing_layers': sorted(contributions.keys()),
                'ablation_applied': True,
                'ablation_type': ablation_label
            }
            if extra_metadata:
                result.update(extra_metadata)
            return result
        except Exception as e:
            # Top-level boundary: experiments are batched, so report failure in the result.
            logger.warning(f"Perturbation experiment failed: {e}")
            try:
                feature_records = self._feature_set_records(feature_set)
            except (TypeError, ValueError):
                # A malformed feature_set must not crash the error path itself.
                feature_records = []
            return {
                'baseline_top_tokens': [],
                'ablated_top_tokens': [],
                'feature_set': feature_records,
                'feature_set_size': len(feature_records),
                'intervention_strength': intervention_strength,
                'probability_change': 0.0,
                'logit_change': 0.0,
                'kl_divergence': 0.0,
                'entropy_change': 0.0,
                'hidden_state_delta_norm': 0.0,
                'hidden_state_relative_change': 0.0,
                'ablation_flips_top_prediction': False,
                'ablation_applied': False,
                'ablation_type': ablation_label,
                'error': str(e)
            }

    def feature_ablation_experiment(
        self,
        input_text: str,
        target_layer: int,
        target_feature: int,
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """Ablate a single ``(target_layer, target_feature)``; see feature_set_ablation_experiment."""
        return self.feature_set_ablation_experiment(
            input_text=input_text,
            feature_set=[(target_layer, target_feature)],
            intervention_strength=intervention_strength,
            target_token_id=target_token_id,
            top_k=top_k,
            ablation_label="targeted_feature"
        )

    def random_feature_ablation_experiment(
        self,
        input_text: str,
        num_features: int = 1,
        intervention_strength: float = 5.0,
        target_token_id: Optional[int] = None,
        top_k: int = 5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """Ablate ``num_features`` uniformly random (layer, feature) pairs as a baseline.

        Pairs are drawn with replacement from a ``random.Random(seed)`` stream,
        and the seed is recorded under ``'random_seed'``.
        """
        rng = random.Random(seed)
        num_features = max(1, int(num_features))
        feature_set: List[Tuple[int, int]] = [
            (rng.randrange(self.clt.n_layers), rng.randrange(self.clt.n_features))
            for _ in range(num_features)
        ]
        return self.feature_set_ablation_experiment(
            input_text=input_text,
            feature_set=feature_set,
            intervention_strength=intervention_strength,
            target_token_id=target_token_id,
            top_k=top_k,
            ablation_label="random_baseline",
            extra_metadata={'random_seed': seed}
        )
class AttributionGraphsPipeline:
# The main pipeline for the attribution graph analysis.
def __init__(self, config: AttributionGraphConfig):
    """Load the tokenizer and model, then wire up the CLT and analysis helpers.

    On MPS and CUDA the model is loaded in float16, otherwise in float32.
    CUDA uses ``device_map="auto"``; MPS and CPU load without a device map
    and are then moved to ``config.device`` explicitly.
    """
    self.config = config
    self.device = torch.device(config.device)
    logger.info(f"Loading OLMo2 7B model from {config.model_path}")
    self.tokenizer = AutoTokenizer.from_pretrained(config.model_path)

    on_mps = "mps" in config.device
    on_cuda = not on_mps and "cuda" in config.device
    loaded_model = AutoModelForCausalLM.from_pretrained(
        config.model_path,
        torch_dtype=torch.float16 if (on_mps or on_cuda) else torch.float32,
        device_map="auto" if on_cuda else None
    )
    # With device_map="auto" the weights are already placed; otherwise move them.
    self.model = loaded_model if on_cuda else loaded_model.to(self.device)

    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token

    self.clt = CrossLayerTranscoder(self.model.config.to_dict(), config).to(self.device)
    # Persistent caching is disabled so interpretations stay prompt-specific.
    self.feature_visualizer = FeatureVisualizer(self.tokenizer, cache_dir=None)
    self.attribution_graph = AttributionGraph(self.clt, self.tokenizer, config)
    self.perturbation_experiments = PerturbationExperiments(self.model, self.clt, self.tokenizer)
    logger.info("Attribution Graphs Pipeline initialized successfully")
def train_clt(self, training_texts: List[str]) -> Dict:
    """Train the Cross-Layer Transcoder on hidden states of the frozen model.

    Each step samples ``config.batch_size`` texts (with replacement) and runs
    the frozen model to collect hidden states (index 0 dropped). It then
    minimises, averaged over the batch::

        reconstruction_loss_weight * sum_l MSE(recon_l, hidden_l)
        + sparsity_lambda * sum_l mean(|features_l|)

    The sparsity weight is applied exactly once. Previously it was applied both
    inside ``tanh`` and as the loss weight, which effectively scaled the
    penalty by lambda squared.

    NOTE(review): the reconstruction targets are the full hidden states, not
    MLP outputs; confirm this is intended.

    Args:
        training_texts: Non-empty pool of texts to sample from.

    Returns:
        Dict with one entry per step under ``'reconstruction_losses'``,
        ``'sparsity_losses'`` and ``'total_losses'``.

    Raises:
        ValueError: If ``training_texts`` is empty.
    """
    if len(training_texts) == 0:
        raise ValueError("train_clt requires at least one training text.")
    logger.info("Starting CLT training...")
    optimizer = torch.optim.Adam(self.clt.parameters(), lr=self.config.learning_rate)
    # The host model may run in float16 while the CLT is float32; train in the CLT's dtype.
    clt_dtype = next(self.clt.parameters()).dtype
    batch_size = self.config.batch_size
    training_stats = {
        'reconstruction_losses': [],
        'sparsity_losses': [],
        'total_losses': []
    }
    for step in tqdm(range(self.config.training_steps), desc="Training CLT"):
        batch_texts = np.random.choice(training_texts, size=batch_size)
        optimizer.zero_grad()
        total_loss = 0.0
        total_recon_loss = 0.0
        total_sparsity_loss = 0.0
        for text in batch_texts:
            inputs = self.tokenizer(text, return_tensors="pt", max_length=self.config.max_seq_length,
                                    truncation=True, padding=True).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
            hidden_states = [h.to(clt_dtype) for h in outputs.hidden_states[1:]]
            feature_activations, reconstructed_outputs = self.clt(hidden_states)
            recon_loss = sum(
                F.mse_loss(pred, target)
                for target, pred in zip(hidden_states, reconstructed_outputs)
            )
            # L1 penalty on feature activations.
            sparsity_loss = sum(features.abs().mean() for features in feature_activations)
            loss = (self.config.reconstruction_loss_weight * recon_loss +
                    self.config.sparsity_lambda * sparsity_loss)
            # Backpropagate per sample, scaled to the batch mean. Gradients equal
            # those of the averaged batch loss, but only one sample's graph is
            # held in memory at a time.
            (loss / batch_size).backward()
            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_sparsity_loss += sparsity_loss.item()
        optimizer.step()
        total_loss /= batch_size
        total_recon_loss /= batch_size
        total_sparsity_loss /= batch_size
        training_stats['total_losses'].append(total_loss)
        training_stats['reconstruction_losses'].append(total_recon_loss)
        training_stats['sparsity_losses'].append(total_sparsity_loss)
        if step % 100 == 0:
            logger.info(f"Step {step}: Total Loss = {total_loss:.4f}, "
                        f"Recon Loss = {total_recon_loss:.4f}, "
                        f"Sparsity Loss = {total_sparsity_loss:.4f}")
    logger.info("CLT training completed")
    return training_stats
def analyze_prompt(self, prompt: str, target_token_idx: int = -1) -> Dict:
    """Run the complete attribution analysis for a single prompt.

    Steps: tokenize the prompt, collect the base model's hidden states, encode
    them with the CLT, visualize/interpret the top features of every layer,
    build and prune the attribution graph, rank embedding→output paths, then
    run four families of ablation experiments (targeted features, random
    features, whole paths, random "paths") and compare their summaries.

    Args:
        prompt: Text to analyze.
        target_token_idx: Forwarded unchanged to
            ``self.attribution_graph.construct_graph``.

    Returns:
        Dict with ``prompt``, ``input_tokens``, ``feature_visualizations``,
        ``full_graph_stats``, ``pruned_graph_stats``, ``important_paths``
        (top 5 by weight), ``graph`` (the pruned graph object), the four
        experiment lists and ``summary_statistics``.
    """
    logger.info(f"Analyzing prompt: '{prompt[:50]}...'")
    inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
    input_tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # The base model is only read from, so its forward runs without autograd.
    with torch.no_grad():
        outputs = self.model(**inputs, output_hidden_states=True)
    # Index 0 is skipped (for Hugging Face models it is the embedding output).
    hidden_states = outputs.hidden_states[1:]
    feature_activations, _ = self.clt(hidden_states)

    logger.info(" > Starting feature visualization and interpretation...")
    feature_visualizations = self._visualize_top_layer_features(feature_activations, input_tokens)

    graph = self.attribution_graph.construct_graph(
        input_tokens, feature_activations, target_token_idx
    )
    pruned_graph = self.attribution_graph.prune_graph(self.config.pruning_threshold)
    important_paths = self._rank_embedding_to_output_paths(pruned_graph)

    # Experiments run in the same order as before, in case they share RNG state.
    targeted_feature_ablation_results = self._run_targeted_feature_ablations(prompt, feature_visualizations)
    random_baseline_results = self._run_random_feature_baselines(prompt)
    path_ablation_results, extracted_paths = self._run_path_ablations(prompt, important_paths)
    random_path_baseline_results = self._run_random_path_baselines(prompt, extracted_paths)

    targeted_summary = self._summarize_ablation_results(targeted_feature_ablation_results)
    random_summary = self._summarize_ablation_results(random_baseline_results)
    path_summary = self._summarize_ablation_results(path_ablation_results)
    random_path_summary = self._summarize_ablation_results(random_path_baseline_results)

    def _difference(treated: Dict[str, Any], baseline: Dict[str, Any], metric: str) -> float:
        # Missing metrics count as 0.0 on either side.
        return treated.get(metric, 0.0) - baseline.get(metric, 0.0)

    summary_statistics = {
        'targeted': targeted_summary,
        'random_baseline': random_summary,
        'path': path_summary,
        'random_path_baseline': random_path_summary,
        'target_minus_random_abs_probability_change': _difference(targeted_summary, random_summary, 'avg_abs_probability_change'),
        'target_flip_rate_minus_random': _difference(targeted_summary, random_summary, 'flip_rate'),
        'path_minus_random_abs_probability_change': _difference(path_summary, random_path_summary, 'avg_abs_probability_change'),
        'path_flip_rate_minus_random': _difference(path_summary, random_path_summary, 'flip_rate'),
    }
    return {
        'prompt': prompt,
        'input_tokens': input_tokens,
        'feature_visualizations': feature_visualizations,
        'full_graph_stats': {
            'n_nodes': len(graph.nodes()),
            'n_edges': len(graph.edges()),
            'node_types': dict(self.attribution_graph.node_types)
        },
        'pruned_graph_stats': {
            'n_nodes': len(pruned_graph.nodes()),
            'n_edges': len(pruned_graph.edges())
        },
        'important_paths': important_paths[:5],  # Top 5 paths by weight
        'graph': pruned_graph,
        'perturbation_experiments': targeted_feature_ablation_results,
        'random_baseline_experiments': random_baseline_results,
        'path_ablation_experiments': path_ablation_results,
        'random_path_baseline_experiments': random_path_baseline_results,
        'summary_statistics': summary_statistics
    }

def _visualize_top_layer_features(self, feature_activations, input_tokens: List[str]) -> Dict[str, Dict[str, Any]]:
    """Visualize and interpret the 5 highest mean-activation features of each layer.

    Returns ``{"layer_<i>": {"feature_<j>": viz}}`` where ``viz`` is the
    visualizer's payload with an added ``'interpretation'`` entry.
    """
    feature_visualizations: Dict[str, Dict[str, Any]] = {}
    for layer_idx, features in enumerate(feature_activations):
        logger.info(f" - Processing Layer {layer_idx}...")
        # Mean over dims 0 and 1; assumed [batch_size, seq_len, n_features] — TODO confirm.
        feature_importance = torch.mean(features, dim=(0, 1))
        top_k = min(5, feature_importance.size(0))
        layer_viz: Dict[str, Any] = {}
        for feat_idx in torch.topk(feature_importance, k=top_k).indices.tolist():
            viz = self.feature_visualizer.visualize_feature(
                feat_idx, layer_idx, features[0], input_tokens
            )
            interpretation = self.feature_visualizer.interpret_feature(
                feat_idx, layer_idx, viz, self.config.qwen_api_config
            )
            viz['interpretation'] = interpretation
            layer_viz[f"feature_{feat_idx}"] = viz
        feature_visualizations[f"layer_{layer_idx}"] = layer_viz
    return feature_visualizations

def _rank_embedding_to_output_paths(self, pruned_graph) -> List[Dict[str, Any]]:
    """Collect embedding→output paths from *pruned_graph*, heaviest first.

    For the first 3 embedding nodes (in ``node_types`` order) and every output
    node, the first 2 simple paths of at most 5 edges produced by
    ``nx.all_simple_paths`` are kept. These are enumeration order, not the 2
    heaviest. A path's weight is the product of the absolute edge weights
    along it; edges missing from ``edge_weights`` count as 0.
    """
    from itertools import islice

    important_paths: List[Dict[str, Any]] = []
    if len(pruned_graph.nodes()) == 0:
        return important_paths
    node_types = self.attribution_graph.node_types
    embedding_nodes = [node for node, type_ in node_types.items()
                       if type_ == "embedding" and node in pruned_graph]
    output_nodes = [node for node, type_ in node_types.items()
                    if type_ == "output" and node in pruned_graph]
    for emb_node in embedding_nodes[:3]:
        for out_node in output_nodes:
            try:
                # islice stops enumeration after 2 paths instead of materializing
                # every simple path (whose count can grow exponentially).
                paths = list(islice(
                    nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5), 2
                ))
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                continue
            for path in paths:
                path_weight = 1.0
                for src, dst in zip(path, path[1:]):
                    path_weight *= abs(self.attribution_graph.edge_weights.get((src, dst), 0.0))
                important_paths.append({
                    'path': path,
                    'weight': path_weight,
                    'description': self._describe_path(path)
                })
    important_paths.sort(key=lambda x: x['weight'], reverse=True)
    return important_paths

def _run_targeted_feature_ablations(self, prompt: str, feature_visualizations: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Ablate each visualized feature within the configured experiment budgets.

    ``config.ablation_features_per_layer`` caps features per layer, and
    ``config.max_ablation_experiments`` caps the total; ``None`` disables a cap.
    Layer and feature indices are parsed from the ``layer_<i>`` and
    ``feature_<j>`` keys. Unparsable keys are logged and skipped.
    """
    results: List[Dict[str, Any]] = []
    max_total_experiments = self.config.max_ablation_experiments
    per_layer_limit = self.config.ablation_features_per_layer
    for layer_name, layer_features in feature_visualizations.items():
        try:
            layer_idx = int(layer_name.split('_')[1])
        except (IndexError, ValueError):
            logger.warning(f"Unable to parse layer index from key '{layer_name}'. Skipping perturbation experiments for this layer.")
            continue
        feature_items = list(layer_features.items())
        if per_layer_limit is not None:
            feature_items = feature_items[:per_layer_limit]
        for feature_name, feature_payload in feature_items:
            # Global budget exhausted: stop all remaining layers too.
            if max_total_experiments is not None and len(results) >= max_total_experiments:
                return results
            try:
                feature_idx = int(feature_name.split('_')[1])
            except (IndexError, ValueError):
                logger.warning(f"Unable to parse feature index from key '{feature_name}'. Skipping perturbation experiment.")
                continue
            ablation = self.perturbation_experiments.feature_ablation_experiment(
                prompt,
                layer_idx,
                feature_idx,
                intervention_strength=self.config.intervention_strength,
                target_token_id=None,
                top_k=self.config.ablation_top_k_tokens,
            )
            ablation.update({
                'layer_name': layer_name,
                'feature_name': feature_name,
                'feature_interpretation': feature_payload.get('interpretation'),
                'feature_max_activation': feature_payload.get('max_activation'),
            })
            results.append(ablation)
    return results

def _run_random_feature_baselines(self, prompt: str) -> List[Dict[str, Any]]:
    """Run ``config.random_baseline_trials`` random-feature ablations as a baseline.

    Each trial ablates ``config.random_baseline_features`` features (1 when
    unset or 0). When ``config.random_baseline_seed`` is set, trial *i* is
    given seed ``random_baseline_seed + i``.
    """
    results: List[Dict[str, Any]] = []
    baseline_trials = self.config.random_baseline_trials
    if not (baseline_trials and baseline_trials > 0):
        return results
    num_features = self.config.random_baseline_features or 1
    base_seed = self.config.random_baseline_seed
    for trial_idx in range(baseline_trials):
        seed = None if base_seed is None else base_seed + trial_idx
        random_result = self.perturbation_experiments.random_feature_ablation_experiment(
            prompt,
            num_features=num_features,
            intervention_strength=self.config.intervention_strength,
            target_token_id=None,
            top_k=self.config.ablation_top_k_tokens,
            seed=seed
        )
        random_result['trial_index'] = trial_idx
        results.append(random_result)
    return results

def _run_path_ablations(self, prompt: str, important_paths: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Ablate the feature set of each of the top ``config.path_ablation_top_k`` paths.

    Returns ``(results, extracted_paths)``. ``extracted_paths`` holds copies of
    the ablated path records with an added ``'feature_set'``. Paths that
    contain no feature nodes are skipped.
    """
    path_ablation_results: List[Dict[str, Any]] = []
    extracted_paths: List[Dict[str, Any]] = []
    max_paths = self.config.path_ablation_top_k or 0
    if max_paths <= 0:
        return path_ablation_results, extracted_paths
    for path_info in important_paths[:max_paths]:
        feature_set = self._extract_feature_set_from_path(path_info.get('path', []))
        if not feature_set:
            continue
        path_result = self.perturbation_experiments.feature_set_ablation_experiment(
            prompt,
            feature_set=feature_set,
            intervention_strength=self.config.intervention_strength,
            target_token_id=None,
            top_k=self.config.ablation_top_k_tokens,
            ablation_label="path",
            extra_metadata={
                'path_nodes': path_info.get('path'),
                'path_description': path_info.get('description'),
                'path_weight': path_info.get('weight')
            }
        )
        path_ablation_results.append(path_result)
        enriched_path_info = path_info.copy()
        enriched_path_info['feature_set'] = feature_set
        extracted_paths.append(enriched_path_info)
    return path_ablation_results, extracted_paths

def _run_random_path_baselines(self, prompt: str, extracted_paths: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Ablate random feature sets matched in size and layer span to the ablated paths.

    Trial *t* mirrors ``extracted_paths[t % len(extracted_paths)]``. It draws
    graph feature nodes with a ``random.Random`` seeded by
    ``config.random_baseline_seed``, keeps distinct ``(layer, feature)`` pairs
    inside the path's layer range that are not in the path itself, and gives
    up after ``5 * size`` draws. Trials that collect too few features are
    skipped.
    """
    results: List[Dict[str, Any]] = []
    path_baseline_trials = self.config.random_path_baseline_trials
    if not (path_baseline_trials and path_baseline_trials > 0 and extracted_paths):
        return results
    rng = random.Random(self.config.random_baseline_seed)
    feature_nodes = [
        item for item in self.attribution_graph.node_types.items()
        if item[1] == "feature"
    ]
    for trial in range(path_baseline_trials):
        selected_path = extracted_paths[trial % len(extracted_paths)]
        reference_features = selected_path.get('feature_set', [])
        target_length = len(reference_features)
        source_layers = [layer for layer, _ in reference_features]
        min_layer = min(source_layers) if source_layers else 0
        max_layer = max(source_layers) if source_layers else self.clt.n_layers - 1
        excluded_keys = {(layer, feature) for layer, feature in reference_features}
        random_feature_set: List[Tuple[int, int]] = []
        attempts = 0
        while len(random_feature_set) < target_length and attempts < target_length * 5:
            attempts += 1
            if not feature_nodes:
                break
            node_name, _node_type = rng.choice(feature_nodes)
            metadata = self.attribution_graph.feature_metadata.get(node_name)
            if metadata is None:
                continue
            if metadata['layer'] < min_layer or metadata['layer'] > max_layer:
                continue
            key = (metadata['layer'], metadata['feature_index'])
            if key in excluded_keys or key in random_feature_set:
                continue
            random_feature_set.append(key)
        # Also rejects an empty sample, since max(1, ...) is at least 1.
        if len(random_feature_set) < max(1, target_length):
            continue
        random_path_result = self.perturbation_experiments.feature_set_ablation_experiment(
            prompt,
            feature_set=random_feature_set,
            intervention_strength=self.config.intervention_strength,
            target_token_id=None,
            top_k=self.config.ablation_top_k_tokens,
            ablation_label="random_path_baseline",
            extra_metadata={
                'trial_index': trial,
                'sampled_feature_set': random_feature_set,
                'reference_path_weight': selected_path.get('weight')
            }
        )
        results.append(random_path_result)
    return results
def _extract_feature_set_from_path(self, path: List[str]) -> List[Tuple[int, int]]:
feature_set: List[Tuple[int, int]] = []
seen: Set[Tuple[int, int]] = set()
for node in path:
if not isinstance(node, str):
continue
if not node.startswith("feat_"):
continue
parts = node.split('_')
try:
layer_str = parts[1] # e.g., "L0"
feature_str = parts[3] # e.g., "F123"
layer_idx = int(layer_str[1:])
feature_idx = int(feature_str[1:])
except (IndexError, ValueError):
continue
key = (layer_idx, feature_idx)
if key not in seen:
seen.add(key)
feature_set.append(key)
return feature_set
def _summarize_ablation_results(self, experiments: List[Dict[str, Any]]) -> Dict[str, Any]:
summary = {
'count': len(experiments),
'avg_probability_change': 0.0,
'avg_abs_probability_change': 0.0,
'std_probability_change': 0.0,
'avg_logit_change': 0.0,
'avg_abs_logit_change': 0.0,
'std_logit_change': 0.0,
'avg_kl_divergence': 0.0,
'avg_entropy_change': 0.0,
'avg_hidden_state_delta_norm': 0.0,
'avg_hidden_state_relative_change': 0.0,
'flip_rate': 0.0,
'count_flipped': 0
}
if not experiments:
return summary
probability_changes = np.array([exp.get('probability_change', 0.0) for exp in experiments], dtype=float)
logit_changes = np.array([exp.get('logit_change', 0.0) for exp in experiments], dtype=float)
kl_divergences = np.array([exp.get('kl_divergence', 0.0) for exp in experiments], dtype=float)
entropy_changes = np.array([exp.get('entropy_change', 0.0) for exp in experiments], dtype=float)
hidden_norms = np.array([exp.get('hidden_state_delta_norm', 0.0) for exp in experiments], dtype=float)
hidden_relative = np.array([exp.get('hidden_state_relative_change', 0.0) for exp in experiments], dtype=float)
flip_flags = np.array([1.0 if exp.get('ablation_flips_top_prediction') else 0.0 for exp in experiments], dtype=float)
# Helper to safely compute mean/std ignoring NaNs
def safe_mean(arr):
with np.errstate(all='ignore'):
m = np.nanmean(arr)
return float(m) if np.isfinite(m) else 0.0
def safe_std(arr):
with np.errstate(all='ignore'):
s = np.nanstd(arr)
return float(s) if np.isfinite(s) else 0.0
summary.update({
'avg_probability_change': safe_mean(probability_changes),
'avg_abs_probability_change': safe_mean(np.abs(probability_changes)),
'std_probability_change': safe_std(probability_changes),
'avg_logit_change': safe_mean(logit_changes),
'avg_abs_logit_change': safe_mean(np.abs(logit_changes)),
'std_logit_change': safe_std(logit_changes),
'avg_kl_divergence': safe_mean(kl_divergences),
'avg_entropy_change': safe_mean(entropy_changes),
'avg_hidden_state_delta_norm': safe_mean(hidden_norms),
'avg_hidden_state_relative_change': safe_mean(hidden_relative),
'flip_rate': safe_mean(flip_flags),
'count_flipped': int(np.round(np.nansum(flip_flags)))
})
return summary
def analyze_prompts_batch(self, prompts: List[str]) -> Dict[str, Any]:
    """Analyze every prompt in *prompts* and pool their ablation experiments.

    Per-prompt results are stored under ``"prompt_<n>"`` (1-based). Experiments
    of each kind are pooled across prompts and summarized. Targeted and path
    ablations are then compared against their random baselines by
    differencing the mean absolute probability change and the flip rate.

    Returns:
        ``{'analyses': ..., 'aggregate_summary': ..., 'prompt_texts': prompts}``.
    """
    # Summary group -> key under which analyze_prompt stores those experiments.
    experiment_sources = {
        'targeted': 'perturbation_experiments',
        'random_baseline': 'random_baseline_experiments',
        'path': 'path_ablation_experiments',
        'random_path_baseline': 'random_path_baseline_experiments',
    }
    pooled: Dict[str, List[Dict[str, Any]]] = {group: [] for group in experiment_sources}
    analyses: Dict[str, Dict[str, Any]] = {}
    prompt_count = len(prompts)
    for position, prompt in enumerate(prompts, start=1):
        logger.info(f"[Batch Eval] Processing prompt {position}/{prompt_count}")
        analysis = self.analyze_prompt(prompt)
        analyses[f"prompt_{position}"] = analysis
        for group, source_key in experiment_sources.items():
            pooled[group].extend(analysis.get(source_key, []))

    aggregate_summary: Dict[str, Any] = {
        group: self._summarize_ablation_results(records)
        for group, records in pooled.items()
    }
    # (output key, treated group, baseline group, metric)
    comparisons = (
        ('target_minus_random_abs_probability_change', 'targeted', 'random_baseline', 'avg_abs_probability_change'),
        ('target_flip_rate_minus_random', 'targeted', 'random_baseline', 'flip_rate'),
        ('path_minus_random_abs_probability_change', 'path', 'random_path_baseline', 'avg_abs_probability_change'),
        ('path_flip_rate_minus_random', 'path', 'random_path_baseline', 'flip_rate'),
    )
    for output_key, treated, baseline, metric in comparisons:
        aggregate_summary[output_key] = (
            aggregate_summary[treated].get(metric, 0.0)
            - aggregate_summary[baseline].get(metric, 0.0)
        )
    return {
        'analyses': analyses,
        'aggregate_summary': aggregate_summary,
        'prompt_texts': prompts
    }
def _describe_path(self, path: List[str]) -> str:
# Generates a human-readable description of a path.
descriptions = []
for node in path:
if self.attribution_graph.node_types[node] == "embedding":
token = node.split('_')[2]
descriptions.append(f"Token '{token}'")
elif self.attribution_graph.node_types[node] == "feature":
parts = node.split('_')
layer = parts[1][1:] # Remove 'L'
feature = parts[3][1:] # Remove 'F'
# Try to get the interpretation.
key = f"L{layer}_F{feature}"
interpretation = self.feature_visualizer.feature_interpretations.get(key, "unknown")
descriptions.append(f"Feature L{layer}F{feature} ({interpretation})")
elif self.attribution_graph.node_types[node] == "output":
descriptions.append("Output")
return " → ".join(descriptions)
def save_results(self, results: Dict, save_path: str):
    """Write *results* to *save_path* as indented JSON.

    *results* is deep-copied first so the caller's graph objects are left
    intact. A top-level ``'graph'`` and any ``'graph'`` inside
    ``results['analyses'][*]`` are converted with ``nx.node_link_data``.
    Values JSON cannot encode natively are written via ``str()``. Missing
    parent directories of *save_path* are created.
    """
    serializable_results = copy.deepcopy(results)
    if 'graph' in serializable_results:
        serializable_results['graph'] = nx.node_link_data(serializable_results['graph'])
    for analysis in serializable_results.get('analyses', {}).values():
        if 'graph' in analysis:
            analysis['graph'] = nx.node_link_data(analysis['graph'])
    # The output directory (e.g. RESULTS_DIR) may not exist yet on a fresh checkout.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2, default=str)
    logger.info(f"Results saved to {save_path}")
def save_clt(self, path: str):
    """Save the CLT's ``state_dict`` to *path*, creating parent directories as needed."""
    # Create e.g. circuit_analysis/models/ on first use so a finished training
    # run is not lost to a missing directory at save time.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(self.clt.state_dict(), path)
    logger.info(f"CLT model saved to {path}")
def load_clt(self, path: str):
    """Restore CLT weights from *path*, move the CLT to ``self.device`` and switch it to eval mode."""
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust. Passing weights_only=True would restrict loading to tensor data.
    state = torch.load(path, map_location=self.device)
    self.clt.load_state_dict(state)
    self.clt.to(self.device)
    self.clt.eval()
    logger.info(f"Loaded CLT model from {path}")
# --- Configuration ---
# Module-level run settings.
MAX_SEQ_LEN = 256  # tokenizer truncation length
N_FEATURES_PER_LAYER = 512
TRAINING_STEPS = 2500
BATCH_SIZE = 64
LEARNING_RATE = 1e-3

# Prompts analyzed in the final report (selected by index from the CLI).
ANALYSIS_PROMPTS = [
    "The capital of France is",
    "def factorial(n):",
    "The literary device in the phrase 'The wind whispered through the trees' is",
]

# Broader prompt pool used for CLT training.
# Every ANALYSIS_PROMPTS entry also appears here.
TRAINING_PROMPTS = [
    "The capital of France is",
    "To be or not to be, that is the",
    "A stitch in time saves",
    "The first person to walk on the moon was",
    "The chemical formula for water is H2O.",
    "Translate to German: 'The cat sits on the mat.'",
    "def factorial(n):",
    "import numpy as np",
    "The main ingredients in a pizza are",
    "What is the powerhouse of the cell?",
    "The equation E=mc^2 relates energy to",
    "Continue the story: Once upon a time, there was a",
    "Classify the sentiment: 'I am overjoyed!'",
    "Extract the entities: 'Apple Inc. is in Cupertino.'",
    "What is the next number: 2, 4, 8, 16, __?",
    "A rolling stone gathers no",
    "The opposite of hot is",
    "import torch",
    "import pandas as pd",
    "class MyClass:",
    "def __init__(self):",
    "The primary colors are",
    "What is the capital of Japan?",
    "Who wrote 'Hamlet'?",
    "The square root of 64 is",
    "The sun rises in the",
    "The Pacific Ocean is the largest ocean on Earth.",
    "The mitochondria is the powerhouse of the cell.",
    "What is the capital of Mongolia?",
    "The movie 'The Matrix' can be classified into the following genre:",
    "The French translation of 'I would like to order a coffee, please.' is:",
    "The literary device in the phrase 'The wind whispered through the trees' is",
    "A Python function that calculates the factorial of a number is:",
    "The main ingredient in a Negroni cocktail is",
    "Summarize the plot of 'Hamlet' in one sentence:",
    "The sentence 'The cake was eaten by the dog' is in the following voice:",
    "A good headline for an article about a new breakthrough in battery technology would be:",
]
# --- Qwen API for Feature Interpretation ---
def get_feature_interpretation_with_qwen(
    api_config: dict,
    top_tokens: list[str],
    feature_name: str,
    layer_index: int,
    max_retries: int = 3,
    initial_backoff: float = 2.0
) -> str:
    """Ask the Qwen chat-completions API for a short label describing a feature.

    Args:
        api_config: Must contain ``api_key``, ``api_endpoint`` (base URL;
            ``/chat/completions`` is appended, trailing ``/`` tolerated) and
            ``model``.
        top_tokens: Tokens that most strongly activate the feature; they are
            quoted into the prompt.
        feature_name: Feature identifier used in the prompt and in log messages.
        layer_index: Layer the feature belongs to.
        max_retries: Number of attempts for transport/HTTP errors.
        initial_backoff: Seconds to wait after the first failed attempt;
            the wait doubles on each further retry.

    Returns:
        The interpretation on success. Otherwise ``"API not configured"`` or a
        string starting with ``"API Error"``; request and response-parsing
        failures are reported this way instead of being raised.
    """
    if not api_config or not api_config.get('api_key'):
        logger.warning("Qwen API not configured. Skipping interpretation.")
        return "API not configured"
    headers = {
        "Authorization": f"Bearer {api_config['api_key']}",
        "Content-Type": "application/json"
    }
    # Create a specialized prompt.
    prompt_text = f"""
You are an expert in transformer interpretability. A feature in a language model (feature '{feature_name}' at layer {layer_index}) is most strongly activated by the following tokens:
{', '.join(f"'{token}'" for token in top_tokens)}
Based *only* on these tokens, what is the most likely function or role of this feature?
Your answer must be a short, concise phrase (e.g., "Detecting proper nouns", "Identifying JSON syntax", "Completing lists", "Recognizing negative sentiment"). Do not write a full sentence.
"""
    data = {
        "model": api_config["model"],
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt_text}]
            }
        ],
        "max_tokens": 50,
        "temperature": 0.1,
        "top_p": 0.9,
        "seed": 42
    }
    logger.info(f" > Interpreting {feature_name} (Layer {layer_index})...")
    for attempt in range(max_retries):
        try:
            logger.info(f" - Attempt {attempt + 1}/{max_retries}: Sending request to Qwen API...")
            # rstrip avoids a '//' when the configured endpoint ends with '/'.
            endpoint = api_config['api_endpoint'].rstrip('/')
            response = requests.post(
                f"{endpoint}/chat/completions",
                headers=headers,
                json=data,
                timeout=60
            )
            response.raise_for_status()  # Raise for 4xx/5xx status codes
            result = response.json()
            interpretation = result["choices"][0]["message"]["content"].strip()
            # Remove quotes from the output.
            if interpretation.startswith('"') and interpretation.endswith('"'):
                interpretation = interpretation[1:-1]
            logger.info(f" - Success! Interpretation: '{interpretation}'")
            return interpretation
        except requests.exceptions.RequestException as e:
            # Transport/HTTP failures are retried with exponential backoff.
            logger.warning(f" - Qwen API request failed (Attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                backoff_time = initial_backoff * (2 ** attempt)
                logger.info(f" - Retrying in {backoff_time:.1f} seconds...")
                time.sleep(backoff_time)
            else:
                logger.error(" - Max retries reached. Failing.")
                return f"API Error: {e}"
        except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
            # Malformed payloads are not retried. This covers e.g. content=None
            # (AttributeError on .strip()) and a non-dict JSON body (TypeError).
            logger.error(f" - Failed to parse Qwen API response: {e}")
            return "API Error: Invalid response format"
        finally:
            # Rate-limit throttle: runs after every attempt, including successful ones.
            time.sleep(2.1)
    # Reached only when max_retries <= 0.
    return "API Error: Max retries exceeded"
def train_transcoder(transcoder, model, tokenizer, training_prompts, device, steps=1000, batch_size=16,
                     optimizer=None, sparsity_lambda=0.01, reconstruction_weight=0.8,
                     sparsity_weight=0.2, max_length=None):
    """Train a Cross-Layer Transcoder to reconstruct the base model's hidden states.

    Each step samples *batch_size* prompts with replacement and tokenizes them
    with padding and truncation. It runs *model* without autograd to get
    ``hidden_states[1:]``, encodes those with *transcoder*, and minimizes::

        reconstruction_weight * sum_layers MSE(pred, target)
        + sparsity_weight * sum_layers mean(tanh(sparsity_lambda * features))

    Args:
        transcoder: Module returning ``(feature_activations, reconstructed_outputs)``.
        model, tokenizer: Base causal LM and its tokenizer.
        training_prompts: Pool of prompt strings to sample from.
        device: Device the tokenized inputs are moved to.
        steps: Number of optimization steps.
        batch_size: Prompts sampled per step.
        optimizer: When ``None``, losses are computed and shown but no
            parameters are updated.
        sparsity_lambda, reconstruction_weight, sparsity_weight: Loss
            coefficients. The defaults reproduce the previously hard-coded
            values (0.01, 0.8, 0.2).
        max_length: Tokenizer truncation length; ``None`` uses ``MAX_SEQ_LEN``.
    """
    if max_length is None:
        max_length = MAX_SEQ_LEN
    transcoder.train()
    progress_bar = tqdm(range(steps), desc="Training CLT")
    for _ in progress_bar:
        batch_prompts = random.choices(training_prompts, k=batch_size)
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # The base model is frozen; only its activations are needed.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[1:]
        feature_activations, reconstructed_outputs = transcoder(hidden_states)

        # NOTE(review): with padding=True, padded positions also enter the MSE;
        # consider masking with inputs["attention_mask"].
        recon_loss = 0.0
        for target, pred in zip(hidden_states, reconstructed_outputs):
            recon_loss += F.mse_loss(pred, target)
        sparsity_loss = 0.0
        for features in feature_activations:
            sparsity_loss += torch.mean(torch.tanh(sparsity_lambda * features))
        loss = reconstruction_weight * recon_loss + sparsity_weight * sparsity_loss

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        progress_bar.set_postfix({
            "Recon Loss": f"{recon_loss.item():.4f}",
            "Sparsity Loss": f"{sparsity_loss.item():.4f}",
            "Total Loss": f"{loss.item():.4f}"
        })
def generate_feature_visualizations(transcoder, model, tokenizer, prompt, device, qwen_api_config=None, graph_config: Optional[AttributionGraphConfig] = None):
    """Visualize, interpret and graph the CLT features active on *prompt*.

    For each CLT layer, the 5 features with the highest mean activation are
    visualized and interpreted through a single ``FeatureVisualizer``. An
    attribution graph is then built and pruned at ``graph_config.pruning_threshold``.
    Paths from embedding to output nodes are ranked by the product of absolute
    edge weights.

    Args:
        transcoder, model, tokenizer: CLT, base model and tokenizer.
        prompt: Text to analyze (truncated to ``MAX_SEQ_LEN`` tokens).
        device: Device the tokenized inputs are moved to.
        qwen_api_config: Forwarded to ``interpret_feature``.
        graph_config: Graph settings; defaults to ``AttributionGraphConfig()``.

    Returns:
        Dict with ``prompt``, ``full_graph_stats``, ``pruned_graph_stats``,
        ``feature_visualizations`` and ``important_paths`` (top 5 by weight).
    """
    from itertools import islice

    if graph_config is None:
        graph_config = AttributionGraphConfig()

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LEN
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states[1:]
    feature_activations, _ = transcoder(hidden_states)

    # One visualizer serves both calls, as the pipeline does with
    # self.feature_visualizer, so any state it keeps is not thrown away.
    visualizer = FeatureVisualizer(tokenizer)
    feature_visualizations = {}
    for layer_idx, features in enumerate(feature_activations):
        # Mean over dims 0 and 1; assumed [batch_size, seq_len, n_features] — TODO confirm.
        feature_importance = torch.mean(features, dim=(0, 1))
        top_k = min(5, feature_importance.size(0))
        layer_viz = {}
        for feat_idx in torch.topk(feature_importance, k=top_k).indices.tolist():
            viz = visualizer.visualize_feature(feat_idx, layer_idx, features[0], input_tokens)
            viz['interpretation'] = visualizer.interpret_feature(feat_idx, layer_idx, viz, qwen_api_config)
            layer_viz[f"feature_{feat_idx}"] = viz
        feature_visualizations[f"layer_{layer_idx}"] = layer_viz

    attribution_graph = AttributionGraph(transcoder, tokenizer, graph_config)
    graph = attribution_graph.construct_graph(
        input_tokens, feature_activations, -1  # No target token for visualization
    )
    pruned_graph = attribution_graph.prune_graph(graph_config.pruning_threshold)

    # Keep the first 2 simple paths (<= 5 edges) from each of the first 3
    # embedding nodes to every output node, then sort by weight.
    important_paths = []
    if len(pruned_graph.nodes()) > 0:
        node_types = attribution_graph.node_types
        embedding_nodes = [node for node, type_ in node_types.items()
                           if type_ == "embedding" and node in pruned_graph]
        output_nodes = [node for node, type_ in node_types.items()
                        if type_ == "output" and node in pruned_graph]
        for emb_node in embedding_nodes[:3]:
            for out_node in output_nodes:
                try:
                    # islice avoids materializing every simple path.
                    paths = list(islice(
                        nx.all_simple_paths(pruned_graph, emb_node, out_node, cutoff=5), 2
                    ))
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    continue
                for path in paths:
                    path_weight = 1.0
                    for src, dst in zip(path, path[1:]):
                        path_weight *= abs(attribution_graph.edge_weights.get((src, dst), 0.0))
                    important_paths.append({
                        'path': path,
                        'weight': path_weight,
                        'description': attribution_graph._describe_path(path)
                    })
    important_paths.sort(key=lambda x: x['weight'], reverse=True)

    return {
        "prompt": prompt,
        "full_graph_stats": {
            "n_nodes": len(graph.nodes()),
            "n_edges": len(graph.edges()),
            "node_types": dict(attribution_graph.node_types)
        },
        "pruned_graph_stats": {
            "n_nodes": len(pruned_graph.nodes()),
            "n_edges": len(pruned_graph.edges())
        },
        "feature_visualizations": feature_visualizations,
        "important_paths": important_paths[:5]  # Top 5 paths
    }
def main():
# Main function to run the analysis for a single prompt.
# Set a seed for reproducibility.
set_seed()
# --- Argument Parser ---
parser = argparse.ArgumentParser(description="Run Attribution Graph analysis for a single prompt.")
parser.add_argument(
'--prompt-index',
type=int,
required=True,
help=f"The 0-based index of the prompt to analyze from the ANALYSIS_PROMPTS list (0 to {len(ANALYSIS_PROMPTS) - 1})."
)
parser.add_argument(
'--force-retrain-clt',
action='store_true',
help="Force re-training of the Cross-Layer Transcoder, even if a saved model exists."
)
parser.add_argument(
'--batch-eval',
action='store_true',
help="Analyze all predefined prompts and compute aggregate faithfulness metrics."
)
args = parser.parse_args()
prompt_idx = args.prompt_index
if not (0 <= prompt_idx < len(ANALYSIS_PROMPTS)):
print(f"❌ Error: --prompt-index must be between 0 and {len(ANALYSIS_PROMPTS) - 1}.")
return
# Get the API config from the utility function.
qwen_api_config = init_qwen_api()
# Configuration - Use consistent settings matching trained CLT
config = AttributionGraphConfig(
model_path="./models/OLMo-2-1124-7B",
n_features_per_layer=512, # Match trained CLT
training_steps=500,
batch_size=4,
max_seq_length=256,
learning_rate=1e-4,
sparsity_lambda=1e-3, # Match training (L1 sparsity)
graph_feature_activation_threshold=0.01,
graph_edge_weight_threshold=0.003,
graph_max_features_per_layer=40,
graph_max_edges_per_node=20,
qwen_api_config=qwen_api_config
)
print("Attribution Graphs for OLMo2 7B - Single Prompt Pipeline")
print("=" * 50)
print(f"Model path: {config.model_path}")
print(f"Device: {config.device}")
try:
# Initialize the full pipeline.
print("🚀 Initializing Attribution Graphs Pipeline...")
pipeline = AttributionGraphsPipeline(config)
print("✓ Pipeline initialized successfully")
print()
# Load an existing CLT model or train a new one.
if os.path.exists(CLT_SAVE_PATH) and not args.force_retrain_clt:
print(f"🧠 Loading existing CLT model from {CLT_SAVE_PATH}...")
pipeline.load_clt(CLT_SAVE_PATH)
print("✓ CLT model loaded successfully.")
else:
if args.force_retrain_clt and os.path.exists(CLT_SAVE_PATH):
print("��‍♂️ --force-retrain-clt flag is set. Overwriting existing model.")
# Train a new CLT model.
print("📚 Training a new CLT model...")
print(f" Training on {len(TRAINING_PROMPTS)} example texts...")
training_stats = pipeline.train_clt(TRAINING_PROMPTS)
print("✓ CLT training completed.")
# Save the training statistics.
stats_save_path = os.path.join(RESULTS_DIR, "clt_training_stats.json")
with open(stats_save_path, 'w') as f:
json.dump(training_stats, f, indent=2)
print(f" Saved training stats to {stats_save_path}")
# Save the new model.
pipeline.save_clt(CLT_SAVE_PATH)
print(f" Saved trained model to {CLT_SAVE_PATH} for future use.")
print()
if args.batch_eval:
print("📊 Running batch faithfulness evaluation across all prompts...")
batch_payload = pipeline.analyze_prompts_batch(ANALYSIS_PROMPTS)
final_results = copy.deepcopy(batch_payload)
final_results['config'] = config.__dict__
final_results['timestamp'] = str(time.time())
for analysis_entry in final_results['analyses'].values():
analysis_entry.pop('graph', None)
batch_save_path = os.path.join(RESULTS_DIR, "attribution_graphs_batch_results.json")
pipeline.save_results(final_results, batch_save_path)
print(f"💾 Batch results saved to {batch_save_path}")
aggregate_summary = batch_payload['aggregate_summary']
targeted_summary = aggregate_summary.get('targeted', {})
random_summary = aggregate_summary.get('random_baseline', {})
path_summary = aggregate_summary.get('path', {})
def _format_summary(label: str, summary: Dict[str, Any]) -> str:
return (
f"{label}: count={summary.get('count', 0)}, "
f"avg|Δp|={summary.get('avg_abs_probability_change', 0.0):.4f}, "
f"flip_rate={summary.get('flip_rate', 0.0):.2%}"
)
print("📈 Aggregate faithfulness summary")
print(f" {_format_summary('Targeted', targeted_summary)}")
print(f" {_format_summary('Random baseline', random_summary)}")
print(f" {_format_summary('Path', path_summary)}")
print(f" {_format_summary('Random path baseline', aggregate_summary.get('random_path_baseline', {}))}")
diff_abs = aggregate_summary.get('target_minus_random_abs_probability_change', 0.0)
diff_flip = aggregate_summary.get('target_flip_rate_minus_random', 0.0)
path_diff_abs = aggregate_summary.get('path_minus_random_abs_probability_change', 0.0)
path_diff_flip = aggregate_summary.get('path_flip_rate_minus_random', 0.0)
print(f" Targeted vs Random |Δp| difference: {diff_abs:.4f}")
print(f" Targeted vs Random flip rate difference: {diff_flip:.4f}")
print(f" Path vs Random path |Δp| difference: {path_diff_abs:.4f}")
print(f" Path vs Random path flip rate difference: {path_diff_flip:.4f}")
print("\n🎉 Batch evaluation completed successfully!")
return
# Analyze the selected prompt.
prompt_to_analyze = ANALYSIS_PROMPTS[prompt_idx]
print(f"🔍 Analyzing prompt {prompt_idx + 1}/{len(ANALYSIS_PROMPTS)}: '{prompt_to_analyze}'")
analysis = pipeline.analyze_prompt(prompt_to_analyze, target_token_idx=-1)
# Display the key results.
print(f" ✓ Tokenized into {len(analysis['input_tokens'])} tokens")
print(f" ✓ Full graph: {analysis['full_graph_stats']['n_nodes']} nodes, {analysis['full_graph_stats']['n_edges']} edges")
print(f" ✓ Pruned graph: {analysis['pruned_graph_stats']['n_nodes']} nodes, {analysis['pruned_graph_stats']['n_edges']} edges")
# Show the top features.
print(" 📊 Top active features:")
feature_layers_items = list(analysis['feature_visualizations'].items())
if config.summary_max_layers is not None:
feature_layers_items = feature_layers_items[:config.summary_max_layers]
for layer_name, layer_features in feature_layers_items:
print(f" {layer_name}:")
feature_items = layer_features.items()
if config.summary_features_per_layer is not None:
feature_items = list(feature_items)[:config.summary_features_per_layer]
for feat_name, feat_data in feature_items:
print(f" {feat_name}: {feat_data['interpretation']} (max: {feat_data['max_activation']:.3f})")
print()
# Summarize perturbation experiments and baselines.
print("🧪 Targeted feature ablations:")
targeted_results = analysis.get('perturbation_experiments', [])
if targeted_results:
for experiment in targeted_results:
layer_name = experiment.get('layer_name', f"L{experiment.get('feature_set', [{}])[0].get('layer', '?')}")
feature_name = experiment.get('feature_name', f"F{experiment.get('feature_set', [{}])[0].get('feature', '?')}")
prob_delta = experiment.get('probability_change', 0.0)
logit_delta = experiment.get('logit_change', 0.0)
flips = experiment.get('ablation_flips_top_prediction', False)
print(f" {layer_name}/{feature_name}: Δp={prob_delta:.4f}, Δlogit={logit_delta:.4f}, flips_top={flips}")
else:
print(" - No targeted ablations were recorded.")
print("\n🎲 Random baseline ablations:")
random_baseline = analysis.get('random_baseline_experiments', [])
if random_baseline:
for experiment in random_baseline:
prob_delta = experiment.get('probability_change', 0.0)
logit_delta = experiment.get('logit_change', 0.0)
flips = experiment.get('ablation_flips_top_prediction', False)
trial_idx = experiment.get('trial_index', '?')
print(f" Trial {trial_idx}: Δp={prob_delta:.4f}, Δlogit={logit_delta:.4f}, flips_top={flips}")
else:
print(" - No random baseline trials were run.")
print("\n🛤️ Path ablations:")
path_results = analysis.get('path_ablation_experiments', [])
if path_results:
for path_exp in path_results:
description = path_exp.get('path_description', 'Path')
prob_delta = path_exp.get('probability_change', 0.0)
logit_delta = path_exp.get('logit_change', 0.0)
flips = path_exp.get('ablation_flips_top_prediction', False)
print(f" {description}: Δp={prob_delta:.4f}, Δlogit={logit_delta:.4f}, flips_top={flips}")
else:
print(" - No path ablations were run.")
summary_stats = analysis.get('summary_statistics', {})
targeted_summary = summary_stats.get('targeted', {})
random_summary = summary_stats.get('random_baseline', {})
path_summary = summary_stats.get('path', {})
random_path_summary = summary_stats.get('random_path_baseline', {})
print("\n📈 Summary statistics:")
print(f" Targeted: avg|Δp|={targeted_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={targeted_summary.get('flip_rate', 0.0):.2%}")
print(f" Random baseline: avg|Δp|={random_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={random_summary.get('flip_rate', 0.0):.2%}")
print(f" Path: avg|Δp|={path_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={path_summary.get('flip_rate', 0.0):.2%}")
print(f" Random path baseline: avg|Δp|={random_path_summary.get('avg_abs_probability_change', 0.0):.4f}, flip_rate={random_path_summary.get('flip_rate', 0.0):.2%}")
print(f" Targeted vs Random |Δp| diff: {summary_stats.get('target_minus_random_abs_probability_change', 0.0):.4f}")
print(f" Targeted vs Random flip diff: {summary_stats.get('target_flip_rate_minus_random', 0.0):.4f}")
print(f" Path vs Random path |Δp| diff: {summary_stats.get('path_minus_random_abs_probability_change', 0.0):.4f}")
print(f" Path vs Random path flip diff: {summary_stats.get('path_flip_rate_minus_random', 0.0):.4f}")
print("\n✓ Faithfulness experiments summarized\n")
# Generate a visualization for the prompt.
print("📈 Generating visualization...")
if 'graph' in analysis and analysis['pruned_graph_stats']['n_nodes'] > 0:
viz_path = os.path.join(RESULTS_DIR, f"attribution_graph_prompt_{prompt_idx + 1}.png")
pipeline.attribution_graph.visualize_graph(analysis['graph'], save_path=viz_path)
print(f" ✓ Graph visualization saved to {viz_path}")
else:
print(" - Skipping visualization as no graph was generated or it was empty.")
# Save the results in a format for the web app.
save_path = os.path.join(RESULTS_DIR, f"attribution_graphs_results_prompt_{prompt_idx + 1}.json")
# Create a JSON file that can be merged with others.
final_results = {
"analyses": {
f"prompt_{prompt_idx + 1}": analysis
},
"config": config.__dict__,
"timestamp": str(time.time())
}
# The web page doesn't use the graph object, so remove it.
if 'graph' in final_results['analyses'][f"prompt_{prompt_idx + 1}"]:
del final_results['analyses'][f"prompt_{prompt_idx + 1}"]['graph']
pipeline.save_results(final_results, save_path)
print(f"💾 Results saved to {save_path}")
print("\n🎉 Analysis for this prompt completed successfully!")
except Exception as e:
print(f"❌ Error during execution: {e}")
import traceback
traceback.print_exc()
# Entry point: run main() only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
main()