#fnord23UFO
"""Gradio app for exploring Gemma-2B sparse-autoencoder (SAE) features.

For a selected residual-stream SAE and feature index, the app builds a feature
vector from the SAE encoder or decoder weights. It optionally re-places that
vector about the token-embedding centroid, and optionally blends in the first
principal component of the token embeddings. It then lists the vocabulary tokens
whose embeddings minimise a power-ratio distance to that vector.

Helpers for growing and rendering a "definition tree" are included as well.
That mode is not yet wired into ``process_input``.
"""
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from safetensors import safe_open
import os
import requests
import json
import math
import numpy as np
from sklearn.decomposition import PCA
import logging
import time
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
import spaces
import traceback
from graphviz import Digraph
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
import functools

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Only report whether the tokens are present, never their values.
logger.info(f"HF_TOKEN_GEMMA set: {'HF_TOKEN_GEMMA' in os.environ}")
logger.info(f"HF_TOKEN_EMBEDDINGS set: {'HF_TOKEN_EMBEDDINGS' in os.environ}")


class Config:
    """Application-wide settings: model id, credentials, device/dtype and tree parameters."""

    def __init__(self):
        self.MODEL_NAME = "google/gemma-2b"
        # Hugging Face access tokens from the environment; either may be None.
        self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
        self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.DTYPE = torch.float32
        self.TOPK = 5  # Children expanded per node when growing the definition tree
        self.CUTOFF = 0.00001  # Cumulative probability cutoff for tree branches
        self.OUTPUT_LENGTH = 20  # NOTE(review): not read anywhere in this file
        self.SUB_TOKEN_ID = 23070  # Arbitrary token ID to overwrite with embedding
        self.LOG_BASE = 10  # Logarithm base used when scaling tree edge widths


config = Config()


def _token_preview(token):
    """Return a log-safe preview of a secret token.

    Returns the first five characters followed by "...", or a placeholder when
    the token is unset. Slicing ``None`` would raise ``TypeError``; inside the
    loaders below that aborted loading whenever an env var was missing.
    """
    return f"{token[:5]}..." if token else "<not set>"


def load_tokenizer():
    """Load the tokenizer for ``config.MODEL_NAME``.

    Returns:
        The tokenizer, or ``None`` if loading failed (the error is logged).
    """
    try:
        logger.info(f"Attempting to load tokenizer with token: {_token_preview(config.ACCESS_TOKEN)}")
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
        logger.info("Tokenizer loaded successfully")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer: {str(e)}")
        return None


def load_model():
    """Load the causal LM for ``config.MODEL_NAME`` with ``device_map="auto"``.

    Returns:
        The model, or ``None`` if loading failed (the error is logged).
    """
    try:
        logger.info(f"Attempting to load model with token: {_token_preview(config.ACCESS_TOKEN)}")
        model = AutoModelForCausalLM.from_pretrained(config.MODEL_NAME, device_map="auto", token=config.ACCESS_TOKEN)
        logger.info("Model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None


def load_token_embeddings():
    """Download and load the precomputed Gemma-2B token-embedding matrix.

    Returns:
        The embeddings tensor on ``config.DEVICE`` cast to ``config.DTYPE``,
        or ``None`` if downloading/loading failed (the error is logged).
    """
    try:
        logger.info(f"Attempting to load token embeddings with token: {_token_preview(config.EMBEDDINGS_TOKEN)}")
        embeddings_path = hf_hub_download(
            repo_id="mwatkins1970/gemma-2b-embeddings",
            filename="gemma_2b_embeddings.pt",
            token=config.EMBEDDINGS_TOKEN,
        )
        logger.info(f"Embeddings downloaded to: {embeddings_path}")
        embeddings = torch.load(embeddings_path, map_location=config.DEVICE, weights_only=True)
        logger.info("Embeddings loaded successfully")
        return embeddings.to(dtype=config.DTYPE)
    except Exception as e:
        logger.error(f"Error loading token embeddings: {str(e)}")
        return None


def load_sae_weights(sae_name):
    """Download (once, into the working directory) and load an SAE's weights.

    Args:
        sae_name: A key of ``sae_urls`` below, e.g. ``"Gemma-2B layer 6"``.

    Returns:
        ``(w_enc, w_dec)`` on ``config.DEVICE`` with dtype ``config.DTYPE``, or
        ``(None, None)`` if the download or the safetensors read failed.

    Raises:
        ValueError: if ``sae_name`` is not a known SAE.
    """
    start_time = time.time()
    base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
    sae_urls = {
        "Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
        "Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
        "Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
        "Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors",
    }
    if sae_name not in sae_urls:
        raise ValueError(f"Unknown SAE: {sae_name}")
    url = f'{base_url}{sae_urls[sae_name]}?download=true'
    local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'
    if not os.path.exists(local_filename):
        # Stream to a temporary file and rename only on success. This keeps the
        # whole file out of memory, and an interrupted download never leaves a
        # truncated file that later calls would mistake for a cached copy.
        partial_filename = f'{local_filename}.part'
        try:
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_filename, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_filename, local_filename)
            logger.info(f'SAE weights for {sae_name} downloaded successfully!')
        except requests.RequestException as e:
            logger.error(f"Failed to download SAE weights for {sae_name}: {str(e)}")
            return None, None
    try:
        with safe_open(local_filename, framework="pt") as f:
            w_dec = f.get_tensor("W_dec").to(device=config.DEVICE, dtype=config.DTYPE)
            w_enc = f.get_tensor("W_enc").to(device=config.DEVICE, dtype=config.DTYPE)
        logger.info(f"Successfully loaded weights for {sae_name}")
        logger.info(f"Time taken to load weights: {time.time() - start_time:.2f} seconds")
        return w_enc, w_dec
    except Exception as e:
        logger.error(f"Error loading SAE weights for {sae_name}: {str(e)}")
        return None, None


@torch.no_grad()
def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor):
    """Select one SAE feature's weight vector, optionally re-placed about the token centroid.

    ``weight_type == "encoder"`` takes column ``feature_number`` of ``w_enc``;
    any other value takes row ``feature_number`` of ``w_dec``.
    NOTE(review): this assumes W_enc is (d_model, n_features) and W_dec is
    (n_features, d_model), as the indexing implies. Confirm against the SAE files.

    When ``use_token_centroid`` is true, the result is
    ``token_centroid + scaling_factor * unit(feature - token_centroid)``.
    """
    if weight_type == "encoder":
        feature_vector = w_enc[:, feature_number]
    else:
        feature_vector = w_dec[feature_number]
    if use_token_centroid:
        offset = feature_vector - token_centroid
        feature_vector = token_centroid + scaling_factor * offset / torch.norm(offset)
    return feature_vector


def perform_pca(_embeddings):
    """Return the unit-norm first principal component of ``_embeddings``.

    The component is computed with scikit-learn on a CPU copy of the data and
    returned on ``config.DEVICE`` with dtype ``config.DTYPE``.

    Raises:
        RuntimeError: if the PCA fails (details and summary stats are logged).
    """
    try:
        logger.info(f"Starting PCA. Embeddings shape: {_embeddings.shape}")
        pca = PCA(n_components=1)
        embeddings_cpu = _embeddings.detach().cpu().numpy()
        logger.info(f"Embeddings converted to numpy. Shape: {embeddings_cpu.shape}")
        pca.fit(embeddings_cpu)
        logger.info("PCA fit completed")
        pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
        logger.info(f"PCA direction calculated. Shape: {pca_direction.shape}")
        normalized_direction = F.normalize(pca_direction, p=2, dim=0)
        logger.info(f"PCA direction normalized. Shape: {normalized_direction.shape}")
        return normalized_direction
    except Exception as e:
        logger.error(f"Error in perform_pca: {str(e)}")
        logger.error(f"Embeddings stats - min: {_embeddings.min()}, max: {_embeddings.max()}, mean: {_embeddings.mean()}, std: {_embeddings.std()}")
        logger.error(traceback.format_exc())
        raise RuntimeError(f"PCA calculation failed: {str(e)}")


@torch.no_grad()
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
    """Place a synthetic embedding at ``target_distance`` from the centroid.

    The direction is ``normalize((1 - pca_weight) * unit(feature - centroid)
    + pca_weight * pca_direction)``. With ``pca_weight == 0`` this is the pure
    feature direction; with ``pca_weight == 1`` it is the pure PCA direction.
    """
    feature_direction = F.normalize(_feature_vector - _token_centroid, p=2, dim=0)
    combined_direction = (1 - pca_weight) * feature_direction + pca_weight * _pca_direction
    combined_direction = F.normalize(combined_direction, p=2, dim=0)
    return _token_centroid + target_distance * combined_direction


@torch.no_grad()
def find_closest_tokens(_emb, _token_embeddings, _tokenizer, top_k=500, num_exp=1.4, denom_exp=1.0):
    """Rank vocabulary tokens by a cosine-distance power ratio to ``_emb``.

    For each token embedding ``t`` the score is
    ``(1 - cos(_emb, t)) ** num_exp / (1 - cos(centroid, t)) ** denom_exp``,
    where ``centroid`` is the mean of ``_token_embeddings``. Smaller means closer.

    Returns:
        Up to ``top_k`` ``(decoded_token, ratio)`` pairs, smallest ratio first.
    """
    token_centroid = torch.mean(_token_embeddings, dim=0)
    emb_norm = F.normalize(_emb.view(1, -1), p=2, dim=1)
    centroid_norm = F.normalize(token_centroid.view(1, -1), p=2, dim=1)
    normalized_embeddings = F.normalize(_token_embeddings, p=2, dim=1)
    similarities_emb = torch.mm(emb_norm, normalized_embeddings.t()).squeeze()
    similarities_centroid = torch.mm(centroid_norm, normalized_embeddings.t()).squeeze()
    distances_emb = torch.pow(1 - similarities_emb, num_exp)
    distances_centroid = torch.pow(1 - similarities_centroid, denom_exp)
    ratios = distances_emb / distances_centroid
    # topk raises if k exceeds the number of candidates.
    top_k = min(int(top_k), ratios.numel())
    top_ratios, top_indices = torch.topk(ratios, k=top_k, largest=False)
    closest_tokens = [_tokenizer.decode([idx.item()]) for idx in top_indices]
    return list(zip(closest_tokens, top_ratios.tolist()))


def get_neuronpedia_url(layer, feature):
    """Return the embeddable Neuronpedia URL for a Gemma-2B residual-stream SAE feature."""
    return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"


# New functions for tree generation and visualization

def update_token_embedding(model, token_id, new_embedding):
    """Overwrite row ``token_id`` of the model's input-embedding matrix in place.

    ``new_embedding`` may be shaped ``(d,)`` or ``(1, d)``. It is flattened and
    moved to the embedding weight's device and dtype before assignment.
    """
    weight = model.get_input_embeddings().weight
    weight.data[token_id] = new_embedding.reshape(-1).to(device=weight.device, dtype=weight.dtype)


def produce_next_token_ids(input_ids, model, topk, sub_token_id):
    """Return the ``topk`` most probable next-token ids and their probabilities.

    The substituted token ``sub_token_id`` is masked to -inf before the softmax,
    so it is never proposed.

    Returns:
        ``(ids, probs)`` for the first sequence in the batch, each of length ``topk``.
    """
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
        last_logits = logits[:, -1, :]
        last_logits[:, sub_token_id] = float('-inf')
        softmax_probs = torch.softmax(last_logits, dim=-1)
        top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
    return top_k_ids[0], top_k_probs[0]


def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
    """Grow the continuation tree depth-first, yielding a progress line per visited node.

    Children are appended in place to ``data['children']`` as dicts with keys
    ``token_id``, ``token``, ``cumulative_prob`` and ``children``. A branch
    stops at ``max_depth`` or when its cumulative probability falls below
    ``config.CUTOFF``.
    """
    if depth >= max_depth or cumulative_prob < config.CUTOFF:
        return
    current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    yield f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}\n"
    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
    for idx, token_id in enumerate(top_k_ids.tolist()):
        if token_id == config.SUB_TOKEN_ID:
            continue
        token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
        new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
        new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
        if new_cumulative_prob < config.CUTOFF:
            continue
        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        new_child = {
            "token_id": token_id,
            "token": token_str,
            "cumulative_prob": new_cumulative_prob,
            "children": [],
        }
        data['children'].append(new_child)
        yield from build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config,
                                  depth=depth + 1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)


def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
    """Write ``embedding`` into ``config.SUB_TOKEN_ID`` and grow a tree from ``base_prompt``.

    This is a generator. It yields progress strings from ``build_def_tree``, and
    the root dict of the finished tree becomes its return value, which callers
    can capture with ``result = yield from ...``.
    Note that it mutates the model's input-embedding matrix.
    """
    results_dict = {"token": "", "cumulative_prob": 1, "children": []}
    token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
    update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)
    if hasattr(model, 'reset_cache'):
        model.reset_cache()
    input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
    yield from build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)
    return results_dict


def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
    """Return ``(max, min_positive)`` of ``cumulative_prob`` over the whole tree rooted at ``node``."""
    current_max = max(current_max, node.get('cumulative_prob', 0))
    if node.get('cumulative_prob', 1) > 0:
        current_min = min(current_min, node.get('cumulative_prob', 1))
    for child in node.get('children', []):
        current_max, current_min = find_max_min_cumulative_weight(child, current_max, current_min)
    return current_max, current_min


def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
    """Map a cumulative probability to a Graphviz pen width.

    The weight is log-normalised to [0, 1] between ``min_weight`` and
    ``max_weight``, raised to the power 2.5 to accentuate strong branches, and
    then mapped linearly onto ``[min_thickness, max_thickness]``.
    """
    cumulative_weight = max(cumulative_weight, min_weight)
    log_weight = math.log(cumulative_weight, log_base) - math.log(min_weight, log_base)
    log_max = math.log(max_weight, log_base) - math.log(min_weight, log_base)
    if log_max <= 0:
        # All weights are equal (max == min): avoid 0/0 and draw at full width,
        # which is the value the formula gives for the maximal weight.
        return float(max_thickness)
    amplified_weight = (log_weight / log_max) ** 2.5
    scaled_weight = (amplified_weight * (max_thickness - min_thickness)) + min_thickness
    return scaled_weight


def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_root=True, depth=0, trim_cutoff=0):
    """Recursively add ``node`` and its descendants to the Graphviz graph ``dot``.

    Non-root nodes with ``cumulative_prob < trim_cutoff`` are skipped together
    with their subtrees. Nodes whose token is blank after ``strip()`` are not
    drawn, and neither are their children.
    """
    node_id = str(id(node))
    token = node.get('token', '').strip()
    cumulative_prob = node.get('cumulative_prob', 1)
    if cumulative_prob < trim_cutoff and not is_root:
        return
    if is_root or token:
        if parent and not is_root:
            edge_weight = scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE)
            dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=str(edge_weight))
        label = "*" if is_root else token
        dot.node(node_id, label=label, shape='plaintext', fontsize="36", fontname='Helvetica')
        for child in node.get('children', []):
            add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False,
                            depth=depth + 1, trim_cutoff=trim_cutoff)


def create_tree_diagram(data, config, max_weight, min_weight, trim_cutoff=0):
    """Render the tree to PNG and centre it vertically on a white canvas at least 5000 px tall.

    Returns:
        A ``BytesIO`` positioned at 0 containing the PNG.
    """
    dot = Digraph(comment='Definition Tree', format='png')
    dot.attr(rankdir='LR', size='5040,5000', margin='0.06', nodesep='0.06', ranksep='1', dpi='120', bgcolor='white')
    add_nodes_edges(dot, data, config, max_weight, min_weight, trim_cutoff=trim_cutoff)
    # ``render`` writes to a file path; ``pipe`` returns the rendered bytes directly.
    png_bytes = dot.pipe(format='png')

    # Add white background; grow the canvas rather than cropping tall trees.
    with Image.open(BytesIO(png_bytes)) as img:
        canvas_height = max(5000, img.height)
        bg = Image.new("RGB", (img.width, canvas_height), (255, 255, 255))
        y_offset = (canvas_height - img.height) // 2
        bg.paste(img, (0, y_offset))
    final_output = BytesIO()
    bg.save(final_output, 'PNG')
    final_output.seek(0)
    return final_output


def trim_tree(trim_cutoff, tree_data):
    """Re-render ``tree_data`` without branches below ``trim_cutoff``.

    Returns:
        A PIL image, or ``None`` when no tree is available. ``gr.Image`` outputs
        accept PIL images, arrays or paths, but not a raw ``BytesIO``.
    """
    if tree_data is None:
        return None
    max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
    png_buffer = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
    return Image.open(png_buffer)


# Global variables to store loaded resources
tokenizer = None
model = None
token_embeddings = None
w_enc_dict = {}
w_dec_dict = {}
_pca_direction_cache = None  # first PCA component of token_embeddings, computed on first use


@functools.lru_cache(maxsize=None)
def cached_load_tokenizer():
    """Memoised ``load_tokenizer``; failures are evicted by ``initialize_resources``."""
    return load_tokenizer()


@functools.lru_cache(maxsize=None)
def cached_load_model():
    """Memoised ``load_model``; failures are evicted by ``initialize_resources``."""
    return load_model()


@functools.lru_cache(maxsize=None)
def cached_load_token_embeddings():
    """Memoised ``load_token_embeddings``; failures are evicted by ``initialize_resources``."""
    return load_token_embeddings()


def initialize_resources():
    """Load the tokenizer, model and token embeddings into the module globals.

    Raises:
        RuntimeError: if any resource fails to load. The failed ``None`` result
            is evicted from its cache first, so a later call retries the load.
    """
    global tokenizer, model, token_embeddings
    logger.info("Initializing resources...")
    tokenizer = cached_load_tokenizer()
    if tokenizer is None:
        cached_load_tokenizer.cache_clear()
        raise RuntimeError("Failed to load tokenizer.")
    model = cached_load_model()
    if model is None:
        cached_load_model.cache_clear()
        raise RuntimeError("Failed to load model.")
    token_embeddings = cached_load_token_embeddings()
    if token_embeddings is None:
        cached_load_token_embeddings.cache_clear()
        raise RuntimeError("Failed to load token embeddings.")
    logger.info("Resources initialized successfully.")


def get_pca_direction():
    """Return the unit first principal component of ``token_embeddings``, computing it only once."""
    global _pca_direction_cache
    if _pca_direction_cache is None:
        _pca_direction_cache = perform_pca(token_embeddings)
    return _pca_direction_cache


@spaces.GPU
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False):
    """Build the feature vector for the selected SAE feature and produce the requested output.

    Only ``mode == "cosine distance token lists"`` is implemented. It returns
    the top-100 tokens, or the full top-500 list with ratios when ``top_500``
    is true.

    Returns:
        ``(text, image)``. ``image`` is always ``None`` here. Errors are
        reported as text rather than raised.
    """
    global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
    try:
        logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")

        # Resources are normally loaded at startup; load them here if this module
        # was imported without running the ``__main__`` block.
        if tokenizer is None or token_embeddings is None:
            initialize_resources()

        # Load the SAE weights if they are not already loaded
        if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
            logger.info("Loading SAE weights for {}".format(selected_sae))
            w_enc, w_dec = load_sae_weights(selected_sae)
            if w_enc is None or w_dec is None:
                error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
                logger.error(error_message)
                return error_message, None
            w_enc_dict[selected_sae] = w_enc
            w_dec_dict[selected_sae] = w_dec
        else:
            w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]

        # Create the feature vector
        token_centroid = torch.mean(token_embeddings, dim=0)
        feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
        if use_pca:
            # Blend the feature direction (relative to the centroid) with the first
            # principal component of the token embeddings, and place the result
            # ``scaling_factor`` away from the centroid.
            feature_vector = create_ghost_token(feature_vector, token_centroid, get_pca_direction(), scaling_factor, float(pca_weight))
        logger.info(f"Feature vector created. Shape: {feature_vector.shape}")

        if mode == "cosine distance token lists":
            logger.info("Generating cosine distance token list")
            closest_tokens_with_values = find_closest_tokens(
                feature_vector, token_embeddings, tokenizer,
                top_k=500, num_exp=num_exp, denom_exp=denom_exp
            )

            if top_500:
                # Generate the top 500 list
                result = "Top 500 list:\n"
                result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
                logger.info("Returning top 500 list")
                return result, None
            # Generate the top 100 list
            token_list = [token for token, _ in closest_tokens_with_values[:100]]
            result = "100 tokens whose embeddings produce the smallest ratio:\n\n"
            result += f"[{', '.join(repr(token) for token in token_list)}]\n"
            logger.info("Returning top 100 tokens")
            return result, None

        # NOTE(review): "definition tree generation" is not wired up yet, so
        # tree_data_state is never populated and the trim controls have no data.
        return "Mode not recognized or not implemented in this step.", None

    except Exception as e:
        logger.error(f"Error in process_input: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error: {str(e)}", None


def gradio_interface():
    """Build the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """

    def update_visibility(mode):
        """Show the tree image and trim controls only in definition-tree mode."""
        show = mode == "definition tree generation"
        return gr.update(visible=show), gr.update(visible=show), gr.update(visible=show)

    def update_neuronpedia(selected_sae, feature_number):
        """Return an iframe that embeds the Neuronpedia page for the selected layer and feature."""
        if not selected_sae or feature_number is None:
            return ""
        layer_number = int(selected_sae.split()[-1])
        url = get_neuronpedia_url(layer_number, int(feature_number))
        return f'<iframe src="{url}" width="100%" height="300" frameborder="0"></iframe>'

    @spaces.GPU
    def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
        # Call process_input without generating the top 500 list initially
        return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False)

    @spaces.GPU
    def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
        # Only the text is needed: this event has a single output component, and
        # returning the (text, image) tuple would show the tuple's repr instead.
        text, _ = process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=True)
        return text

    def toggle_top_500(text):
        """Reveal the top-500 button and textbox only after a top-100 list has been shown."""
        show = bool(text) and "100 tokens" in text
        return gr.update(visible=show), gr.update(visible=show)

    with gr.Blocks() as demo:
        gr.Markdown("# Gemma-2B SAE Feature Explorer (almost there?)")
        with gr.Row():
            with gr.Column(scale=2):
                selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
                feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
                mode = gr.Radio(
                    choices=["cosine distance token lists", "definition tree generation"],
                    label="Select mode",
                    value="cosine distance token lists"
                )
                weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
                use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
                scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
                num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
                denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
                use_pca = gr.Checkbox(label="Introduce first PCA component")
                pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")

            with gr.Column(scale=3):
                generate_btn = gr.Button("Generate Output")
                output_text = gr.Textbox(label="Output", lines=20)
                output_image = gr.Image(label="Tree Diagram", visible=False)
                generate_top_500_btn = gr.Button("Generate Top 500 Tokens and Power Ratios", visible=False)
                output_500_text = gr.Textbox(label="Top 500 Output", lines=20, visible=False)
                trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
                trim_btn = gr.Button("Trim Tree", visible=False)
                tree_data_state = gr.State()
                neuronpedia_html = gr.HTML(label="Neuronpedia")

        inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]

        generate_btn.click(
            update_output,
            inputs=inputs,
            outputs=[output_text, output_image],
            show_progress="full"
        )

        generate_top_500_btn.click(
            generate_top_500,
            inputs=inputs,
            outputs=[output_500_text],
            show_progress="full"
        )

        trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
        mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
        selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
        feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])

        output_text.change(
            toggle_top_500,
            inputs=[output_text],
            outputs=[generate_top_500_btn, output_500_text]
        )

    return demo


if __name__ == "__main__":
    try:
        logger.info("Starting application initialization...")
        initialize_resources()
        logger.info("Creating Gradio interface...")
        iface = gradio_interface()
        logger.info("Launching Gradio interface...")
        iface.launch()
        logger.info("Gradio interface launched successfully")
    except Exception as e:
        logger.error(f"Error during application startup: {str(e)}")
        logger.error(traceback.format_exc())
        # Propagate so the process exits non-zero instead of appearing healthy.
        raise