dlouapre's picture
dlouapre HF Staff
Removing nnsight imports
3a0c265
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
def load_saes_from_file(file_path, cfg, device):
    """
    Load pre-extracted steering vectors from a local file.

    This is much faster than load_saes() since it doesn't download large SAE files.
    The file should be created using extract_steering_vectors.py script.

    Args:
        file_path: Path to the .pt file containing steering vectors. The loaded
            object is indexed with (layer_idx, feature_idx) tuples.
        cfg: Configuration dict with a 'features' list. Each entry is a sequence
            (layer_idx, feature_idx[, strength]); a missing strength defaults
            to 0.0. When the optional 'reduced_strengths' entry is truthy, each
            strength is multiplied by its layer index.
        device: Device to load tensors on ('cuda' or 'cpu')

    Returns:
        List of steering component dicts with keys: 'layer', 'feature', 'strength', 'vector'.
        An empty list is returned when no features are configured.

    Raises:
        FileNotFoundError: If file_path does not exist.
        KeyError: If a configured (layer, feature) pair is not present in the file.
    """
    import os

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Steering vectors file not found: {file_path}\n"
            f"Please run: python extract_steering_vectors.py"
        )

    print(f"Loading pre-extracted steering vectors from {file_path}...")

    # Check the config before touching the file: with nothing to look up there
    # is no reason to pay for deserializing the whole checkpoint.
    features = cfg['features']
    if not features:
        print("No features specified in config.")
        return []

    # NOTE(security): torch.load unpickles the file; only point this at files
    # you produced yourself (e.g. via extract_steering_vectors.py).
    steering_vectors_dict = torch.load(file_path, map_location="cpu")

    reduced_strengths = cfg.get('reduced_strengths', False)
    steering_components = []

    for feature in features:
        layer_idx, feature_idx = feature[0], feature[1]
        strength = feature[2] if len(feature) > 2 else 0.0
        if reduced_strengths:
            # Config holds per-layer "reduced" strengths; scale by layer index.
            strength *= layer_idx

        # Look up the pre-extracted vector
        key = (layer_idx, feature_idx)
        if key not in steering_vectors_dict:
            raise KeyError(
                f"Vector for layer {layer_idx}, feature {feature_idx} not found in {file_path}.\n"
                f"Please re-run: python extract_steering_vectors.py"
            )
        vec = steering_vectors_dict[key].to(device, non_blocking=True)

        # Display the effective strength and, where defined, its per-layer
        # (reduced) equivalent; layer 0 would divide by zero.
        reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
        print(f"Loaded feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")

        steering_components.append({
            'layer': layer_idx,
            'feature': feature_idx,
            'strength': strength,
            'vector': vec,  # Expected to be already normalized in the file
        })

    print(f"Loaded {len(steering_components)} steering vector(s) from local file")
    return steering_components
def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
    """
    Build a forward hook that adds steering vectors to one layer's output.

    Args:
        layer_idx: Index of the layer the hook will be attached to
        steering_components: Steering components for all layers; only those
            whose 'layer' equals layer_idx are used
        clamp_intensity: If True, subtract each vector's existing projection
            from the hidden states before adding the steering amount

    Returns:
        A forward hook function, or None when no component targets this layer
    """
    active = [comp for comp in steering_components if comp['layer'] == layer_idx]
    if len(active) == 0:
        return None

    def hook(module, args, output):
        """Forward hook that returns the output with steered hidden states."""
        # The layer may return a bare tensor or a tuple whose first item is
        # the hidden states; remember which so the format can be restored.
        is_tuple = isinstance(output, tuple)
        states = output[0] if is_tuple else output

        # A 2-D tensor is treated as [batch, hidden_dim]; give it a length-1
        # sequence axis so the arithmetic below is always 3-D.
        added_seq_axis = len(states.shape) == 2
        if added_seq_axis:
            states = states.unsqueeze(1)

        for comp in active:
            # Bring the (already normalized) vector to the activations'
            # dtype and device.
            direction = comp['vector'].to(dtype=states.dtype, device=states.device)

            # Same expansion pattern as the nnsight version:
            # [hidden_dim] -> [seq_len, hidden_dim] -> [1, seq_len, hidden_dim]
            n_positions = states.shape[1]
            scaled = comp['strength'] * direction
            delta = scaled.unsqueeze(0).expand(n_positions, -1).unsqueeze(0)

            if clamp_intensity:
                # Remove the component already present along the vector
                # (prevents over-steering).
                coeffs = torch.einsum('bsh,h->bs', states, direction).unsqueeze(-1)
                delta = delta - coeffs * direction.view(1, 1, -1)

            states = states + delta

        if added_seq_axis:
            states = states.squeeze(1)

        if is_tuple:
            return (states,) + output[1:]
        return states

    return hook
def stream_steered_answer_hf(model: AutoModelForCausalLM,
                             tokenizer: AutoTokenizer,
                             chat,
                             steering_components,
                             max_new_tokens=128,
                             temperature=0.0,
                             repetition_penalty=1.0,
                             clamp_intensity=False,
                             stream=True):
    """
    Generate steered answer using pure HuggingFace Transformers with streaming.

    Forward hooks are registered on the targeted layers for the duration of the
    generation and are always removed again, including when the caller stops
    iterating early or an error occurs.

    Args:
        model: HuggingFace transformers model; decoder layers are accessed
            as model.model.layers[layer_idx]
        tokenizer: Tokenizer instance
        chat: Chat history in OpenAI format
        steering_components: List of dicts with 'layer', 'strength', 'vector'
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0 = greedy)
        repetition_penalty: Repetition penalty
        clamp_intensity: Whether to clamp steering intensity
        stream: Accepted for interface compatibility; not read by this function

    Yields:
        Partial text as tokens are generated (the full text so far, not deltas)

    Raises:
        Exception: Any exception raised by model.generate on the worker thread
            is re-raised here once the stream has ended.
    """
    input_ids_list = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
    input_ids = torch.tensor([input_ids_list]).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": input_ids,
        # Single unpadded prompt: every position is a real token. Passing the
        # mask explicitly avoids ambiguity because pad_token_id == eos_token_id.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": max_new_tokens,
        "temperature": temperature if temperature > 0 else 1.0,
        "do_sample": temperature > 0,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }

    generation_errors = []

    def run_generation():
        """Run generate() and make sure a failure still terminates the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # Without this the consumer loop below would wait on the
            # streamer's queue forever.
            generation_errors.append(exc)
            streamer.end()

    thread = Thread(target=run_generation)
    thread_started = False
    hook_handles = []

    try:
        # Register steering hooks
        layers_to_steer = {sc['layer'] for sc in steering_components}
        for layer_idx in layers_to_steer:
            hook_fn = create_steering_hook(layer_idx, steering_components, clamp_intensity)
            if hook_fn:
                layer_module = model.model.layers[layer_idx]
                hook_handles.append(layer_module.register_forward_hook(hook_fn))

        thread.start()
        thread_started = True

        generated_text = ""
        for token_text in streamer:
            generated_text += token_text
            yield generated_text
    finally:
        # Wait for the worker so the hooks are never removed in the middle of
        # a forward pass, then detach them so later calls are not steered.
        if thread_started:
            thread.join()
        for handle in hook_handles:
            handle.remove()

    if generation_errors:
        raise generation_errors[0]