import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


def load_saes_from_file(file_path, cfg, device):
    """
    Load pre-extracted steering vectors from a local file.

    This is much faster than load_saes() since it doesn't download large SAE files.
    The file should be created using the extract_steering_vectors.py script.

    Args:
        file_path: Path to the .pt file containing steering vectors
        cfg: Configuration dict with a 'features' list
        device: Device to load tensors on ('cuda' or 'cpu')

    Returns:
        List of steering component dicts with keys: 'layer', 'feature', 'strength', 'vector'
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Steering vectors file not found: {file_path}\n"
            f"Please run: python extract_steering_vectors.py"
        )

    print(f"Loading pre-extracted steering vectors from {file_path}...")
    # Load the dictionary of steering vectors, keyed by (layer_idx, feature_idx)
    steering_vectors_dict = torch.load(file_path, map_location="cpu")
    if not cfg['features']:
        print("No features specified in config.")
        return []

    steering_components = []
    features = cfg['features']
    reduced_strengths = cfg.get('reduced_strengths', False)

    for feature in features:
        # Each feature is (layer_idx, feature_idx[, strength])
        layer_idx, feature_idx = feature[0], feature[1]
        strength = feature[2] if len(feature) > 2 else 0.0
        if reduced_strengths:
            strength *= layer_idx

        # Look up the pre-extracted vector
        key = (layer_idx, feature_idx)
        if key not in steering_vectors_dict:
            raise KeyError(
                f"Vector for layer {layer_idx}, feature {feature_idx} not found in {file_path}.\n"
                f"Please re-run: python extract_steering_vectors.py"
            )
        vec = steering_vectors_dict[key].to(device, non_blocking=True)

        # Display the per-layer base strength when reduced_strengths is set
        reduced_str = f"[{strength / layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
        print(f"Loaded feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")

        steering_components.append({
            'layer': layer_idx,
            'feature': feature_idx,
            'strength': strength,
            'vector': vec  # Already normalized in the file
        })

    print(f"Loaded {len(steering_components)} steering vector(s) from local file")
    return steering_components
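
# --- Illustrative sketch (not part of the app): the on-disk format that
# load_saes_from_file() expects, as inferred from the loader above, is a dict
# mapping (layer_idx, feature_idx) tuples to unit-norm 1-D tensors saved with
# torch.save. The helper name and signature below are hypothetical.
def save_steering_vectors_example(vectors, file_path="steering_vectors.pt"):
    """Write {(layer_idx, feature_idx): 1-D tensor}, normalizing each vector."""
    normalized = {key: (vec / vec.norm()).cpu() for key, vec in vectors.items()}
    torch.save(normalized, file_path)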

def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
    """
    Create a forward hook for a specific layer that applies steering.

    Args:
        layer_idx: Which layer this hook is for
        steering_components: List of steering components (all layers)
        clamp_intensity: Whether to clamp steering intensity

    Returns:
        Forward hook function, or None if no components target this layer
    """
    layer_components = [sc for sc in steering_components if sc['layer'] == layer_idx]
    if not layer_components:
        return None
    def hook(module, input, output):
        """Forward hook that modifies the output hidden states."""
        # Handle different output formats (tuple vs tensor)
        if isinstance(output, tuple):
            hidden_states = output[0]
            rest_of_output = output[1:]
        else:
            hidden_states = output
            rest_of_output = None

        # Handle different shapes during generation
        original_shape = hidden_states.shape
        if len(original_shape) == 2:
            # During generation: [batch, hidden_dim] -> add a seq_len dimension
            hidden_states = hidden_states.unsqueeze(1)  # [batch, 1, hidden_dim]

        for sc in layer_components:
            strength = sc['strength']
            vector = sc['vector']  # Already normalized

            # Ensure the vector matches hidden_states dtype and device
            vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)

            # Match nnsight's expansion pattern exactly
            seq_len = hidden_states.shape[1]
            amount = (strength * vector).unsqueeze(0).expand(seq_len, -1).unsqueeze(0)  # [1, seq_len, hidden_dim]

            if clamp_intensity:
                # Subtract the existing projection onto the (unit) steering vector,
                # so the post-steering projection equals exactly `strength`
                # (prevents over-steering)
                projection_scalars = torch.einsum('bsh,h->bs', hidden_states, vector).unsqueeze(-1)
                projection_vectors = projection_scalars * vector.view(1, 1, -1)
                amount = amount - projection_vectors

            hidden_states = hidden_states + amount

        # Restore the original shape if we added a dimension
        if len(original_shape) == 2:
            hidden_states = hidden_states.squeeze(1)  # [batch, hidden_dim]

        # Return in the same format as the input
        if rest_of_output is not None:
            return (hidden_states,) + rest_of_output
        return hidden_states

    return hook
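
# --- Illustrative sanity check (hypothetical helper, not part of the app). ---
# With clamp_intensity=True the hook removes the hidden states' existing
# component along the unit steering vector before adding `strength * vector`,
# so the steered output's projection onto that vector should equal `strength`.
# `hidden_dim` and the Identity module are arbitrary choices for the demo.
def _sanity_check_clamped_hook(hidden_dim=16, strength=4.0):
    vec = torch.randn(hidden_dim)
    vec = vec / vec.norm()  # the hooks assume unit-norm vectors
    components = [{'layer': 0, 'feature': 0, 'strength': strength, 'vector': vec}]
    hook = create_steering_hook(0, components, clamp_intensity=True)
    dummy = torch.nn.Identity()
    handle = dummy.register_forward_hook(hook)
    out = dummy(torch.randn(2, 3, hidden_dim))  # [batch, seq_len, hidden_dim]
    handle.remove()
    # Every projection should be ~`strength` (up to float error)
    assert torch.allclose(torch.einsum('bsh,h->bs', out, vec),
                          torch.full((2, 3), strength), atol=1e-5)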

def stream_steered_answer_hf(model: AutoModelForCausalLM,
                             tokenizer: AutoTokenizer,
                             chat,
                             steering_components,
                             max_new_tokens=128,
                             temperature=0.0,
                             repetition_penalty=1.0,
                             clamp_intensity=False,
                             stream=True):
    """
    Generate a steered answer using pure HuggingFace Transformers with streaming.

    Args:
        model: HuggingFace transformers model
        tokenizer: Tokenizer instance
        chat: Chat history in OpenAI format
        steering_components: List of dicts with 'layer', 'strength', 'vector'
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (0 = greedy)
        repetition_penalty: Repetition penalty
        clamp_intensity: Whether to clamp steering intensity
        stream: Unused in this function; output is always streamed

    Yields:
        Accumulated text so far as tokens are generated
    """
    input_ids_list = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
    input_ids = torch.tensor([input_ids_list]).to(model.device)
    # Register steering hooks on each layer that has components
    hook_handles = []
    layers_to_steer = set(sc['layer'] for sc in steering_components)
    for layer_idx in layers_to_steer:
        hook_fn = create_steering_hook(layer_idx, steering_components, clamp_intensity)
        if hook_fn:
            layer_module = model.model.layers[layer_idx]
            handle = layer_module.register_forward_hook(hook_fn)
            hook_handles.append(handle)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature if temperature > 0 else 1.0,
        "do_sample": temperature > 0,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }

    thread = Thread(target=lambda: model.generate(**generation_kwargs))
    thread.start()

    try:
        generated_text = ""
        for token_text in streamer:
            generated_text += token_text
            yield generated_text
        thread.join()
    finally:
        # Always remove hooks, even if the consumer stops iterating early,
        # so later generations are not steered accidentally
        for handle in hook_handles:
            handle.remove()
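
# --- Illustrative usage sketch (assumptions, not from the app): the model name,
# vector file path, and feature tuple below are placeholders; any decoder-only
# model exposing `model.model.layers` should work the same way.
if __name__ == "__main__":
    model_name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )

    cfg = {"features": [(10, 1234, 8.0)]}  # (layer, feature, strength) placeholders
    components = load_saes_from_file("steering_vectors.pt", cfg, model.device)

    chat = [{"role": "user", "content": "Tell me about yourself."}]
    text = ""
    for text in stream_steered_answer_hf(model, tokenizer, chat, components,
                                         max_new_tokens=64):
        pass  # each `text` is the accumulated answer so far
    print(text)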