"""Gradio app for exploring Gemma-2B SAE features: for a chosen SAE feature direction,
list the tokens whose embeddings minimise a cosine-distance ratio to that direction."""

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors import safe_open
import os
import requests
import json
from sklearn.decomposition import PCA
import logging
import time
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
import spaces
import traceback

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info(f"HF_TOKEN_GEMMA set: {'HF_TOKEN_GEMMA' in os.environ}")
logger.info(f"HF_TOKEN_EMBEDDINGS set: {'HF_TOKEN_EMBEDDINGS' in os.environ}")


class Config:
    def __init__(self):
        self.MODEL_NAME = "google/gemma-2b"
        self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
        self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.DTYPE = torch.float32


config = Config()


def load_tokenizer():
    try:
        logger.info(f"Attempting to load tokenizer with token: {config.ACCESS_TOKEN[:5]}...")
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
        logger.info("Tokenizer loaded successfully")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer: {str(e)}")
        return None


def load_token_embeddings():
    try:
        logger.info(f"Attempting to load token embeddings with token: {config.EMBEDDINGS_TOKEN[:5]}...")
        embeddings_path = hf_hub_download(
            repo_id="mwatkins1970/gemma-2b-embeddings",
            filename="gemma_2b_embeddings.pt",
            token=config.EMBEDDINGS_TOKEN
        )
        logger.info(f"Embeddings downloaded to: {embeddings_path}")
        embeddings = torch.load(embeddings_path, map_location=config.DEVICE)
        logger.info("Embeddings loaded successfully")
        return embeddings.to(dtype=config.DTYPE)
    except Exception as e:
        logger.error(f"Error loading token embeddings: {str(e)}")
        return None


def load_sae_weights(sae_name):
    start_time = time.time()
    base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
    sae_urls = {
        "Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
        "Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
        "Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
        "Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors"
    }

    if sae_name not in sae_urls:
        raise ValueError(f"Unknown SAE: {sae_name}")

    url = f'{base_url}{sae_urls[sae_name]}?download=true'
    local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'

    # Download the weights once and cache them locally
    if not os.path.exists(local_filename):
        try:
            response = requests.get(url)
            response.raise_for_status()
            with open(local_filename, 'wb') as f:
                f.write(response.content)
            logger.info(f'SAE weights for {sae_name} downloaded successfully!')
        except requests.RequestException as e:
            logger.error(f"Failed to download SAE weights for {sae_name}: {str(e)}")
            return None, None

    try:
        with safe_open(local_filename, framework="pt") as f:
            w_dec = f.get_tensor("W_dec").to(device=config.DEVICE, dtype=config.DTYPE)
            w_enc = f.get_tensor("W_enc").to(device=config.DEVICE, dtype=config.DTYPE)
        logger.info(f"Successfully loaded weights for {sae_name}")
        logger.info(f"Time taken to load weights: {time.time() - start_time:.2f} seconds")
        return w_enc, w_dec
    except Exception as e:
        logger.error(f"Error loading SAE weights for {sae_name}: {str(e)}")
        return None, None


@torch.no_grad()
def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid,
                          use_token_centroid, scaling_factor):
    # Select this feature's direction from the encoder (column) or decoder (row) weights
    if weight_type == "encoder":
        feature_vector = w_enc[:, feature_number]
    else:
        feature_vector = w_dec[feature_number]

    # Optionally re-centre the vector at the token centroid and rescale it to a fixed distance
    if use_token_centroid:
        feature_vector = token_centroid + scaling_factor * (feature_vector - token_centroid) / torch.norm(feature_vector - token_centroid)

    return feature_vector


def perform_pca(_embeddings):
    # First principal component of the token embeddings, returned as a unit vector
    pca = PCA(n_components=1)
    pca.fit(_embeddings.cpu().numpy())
    pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
    return F.normalize(pca_direction, p=2, dim=0)


@torch.no_grad()
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
    # Blend the feature direction with the first PCA direction, then place the result
    # at target_distance from the token centroid
    feature_direction = F.normalize(_feature_vector - _token_centroid, p=2, dim=0)
    combined_direction = (1 - pca_weight) * feature_direction + pca_weight * _pca_direction
    combined_direction = F.normalize(combined_direction, p=2, dim=0)
    return _token_centroid + target_distance * combined_direction


@torch.no_grad()
def find_closest_tokens(_emb, _token_embeddings, _tokenizer, top_k=500, num_exp=1.4, denom_exp=1.0):
    """Return the top_k (token, ratio) pairs with the smallest values of
    (cos distance from _emb)^num_exp / (cos distance from token centroid)^denom_exp."""
    token_centroid = torch.mean(_token_embeddings, dim=0)
    emb_norm = F.normalize(_emb.view(1, -1), p=2, dim=1)
    centroid_norm = F.normalize(token_centroid.view(1, -1), p=2, dim=1)
    normalized_embeddings = F.normalize(_token_embeddings, p=2, dim=1)

    similarities_emb = torch.mm(emb_norm, normalized_embeddings.t()).squeeze()
    similarities_centroid = torch.mm(centroid_norm, normalized_embeddings.t()).squeeze()

    distances_emb = torch.pow(1 - similarities_emb, num_exp)
    distances_centroid = torch.pow(1 - similarities_centroid, denom_exp)

    ratios = distances_emb / distances_centroid

    top_ratios, top_indices = torch.topk(ratios, k=top_k, largest=False)
    closest_tokens = [_tokenizer.decode([idx.item()]) for idx in top_indices]

    return list(zip(closest_tokens, top_ratios.tolist()))


def get_neuronpedia_url(layer, feature):
    return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"


# Global variables to store loaded resources
tokenizer = None
token_embeddings = None
w_enc_dict = {}
w_dec_dict = {}
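
# Illustrative usage (hypothetical, not executed by the app): given a feature vector `v` built by
# create_feature_vector() and the loaded `token_embeddings` and `tokenizer`, the call
#     find_closest_tokens(v, token_embeddings, tokenizer, top_k=10)
# would return the 10 (token, ratio) pairs with the smallest distance ratio.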

@spaces.GPU
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
    global tokenizer, token_embeddings, w_enc_dict, w_dec_dict

    # Lazily load the tokenizer, token embeddings and SAE weights, caching them in module globals
    if tokenizer is None:
        tokenizer = load_tokenizer()
        if tokenizer is None:
            return "Failed to load tokenizer. Please check the logs for more details."

    if token_embeddings is None:
        token_embeddings = load_token_embeddings()
        if token_embeddings is None:
            return "Failed to load token embeddings. Please check the logs for more details."

    if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
        w_enc, w_dec = load_sae_weights(selected_sae)
        if w_enc is None or w_dec is None:
            return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
        w_enc_dict[selected_sae] = w_enc
        w_dec_dict[selected_sae] = w_dec
    else:
        w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]

    if w_enc is None or w_dec is None:
        return "Failed to load SAE weights. Please try selecting a different SAE or rerun the app."

    # Build the feature vector (optionally blended with the first PCA direction) and rank tokens
    token_centroid = torch.mean(token_embeddings, dim=0)
    feature_vector = create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor)

    if use_pca:
        pca_direction = perform_pca(token_embeddings)
        feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)

    closest_tokens_with_values = find_closest_tokens(
        feature_vector, token_embeddings, tokenizer,
        top_k=500, num_exp=num_exp, denom_exp=denom_exp
    )

    token_list = [token for token, _ in closest_tokens_with_values]
    result = "100 tokens whose embeddings produce the smallest ratio:\n\n"
    result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
    result += "Top 500 list:\n"
    result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])

    return result


def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Gemma-2B SAE Feature Explorer")

        with gr.Row():
            with gr.Column(scale=2):
                selected_sae = gr.Dropdown(
                    choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"],
                    label="Select SAE"
                )
                feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
                mode = gr.Radio(
                    choices=["cosine distance token lists", "definition tree generation"],
                    label="Select mode",
                    value="cosine distance token lists"
                )
                weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
                use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
                scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
                num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
                denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
                use_pca = gr.Checkbox(label="Introduce first PCA component")
                pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")

            with gr.Column(scale=3):
                generate_btn = gr.Button("Generate Output")

                cosine_output = gr.Column(visible=True)
                with cosine_output:
                    gr.Markdown("100 tokens whose embeddings produce the smallest values of the ratio")
                    gr.Markdown("(cos distance from feature vector)^m/(cos distance from token centroid)^n")
                    output_100 = gr.Textbox(label="Top 100 tokens", lines=10)
                    show_500_btn = gr.Button("Show top 500 tokens and values")
                    output_500 = gr.Textbox(label="Top 500 tokens and values", visible=False, lines=25)

                tree_output = gr.Column(visible=False)
                with tree_output:
                    output = gr.Image(label="Tree Diagram Output")
                    neuronpedia_embed = gr.HTML(label="Neuronpedia Embed")
                    trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")

        def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
            if mode == "cosine distance token lists":
                result = process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp)
                if "Top 500 list:" not in result:
                    # process_input returned an error message rather than token lists
                    return result, "", gr.update(visible=True), gr.update(visible=False)
                top_100 = result.split("Top 500 list:")[0].strip()
                top_500 = "Top 500 list:\n" + result.split("Top 500 list:")[1].strip()
                return top_100, top_500, gr.update(visible=True), gr.update(visible=False)
            else:
                # Placeholder for tree generation functionality
                return "", "", gr.update(visible=False), gr.update(visible=True)

        def show_top_500():
            return gr.update(visible=True)

        def update_ui(mode_selected):
            if mode_selected == "cosine distance token lists":
                return gr.update(visible=True), gr.update(visible=False)
            else:
                return gr.update(visible=False), gr.update(visible=True)

        inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
        mode.change(update_ui, inputs=[mode], outputs=[cosine_output, tree_output])
        generate_btn.click(update_output, inputs=inputs, outputs=[output_100, output_500, cosine_output, tree_output])
        show_500_btn.click(show_top_500, outputs=output_500)

    return demo
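
# Entry point: not present in the excerpt above; a minimal sketch assuming this file is run
# directly as a local script or Hugging Face Space. Adjust launch() arguments as needed.
if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch()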