Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from safetensors import safe_open | |
import os | |
import requests | |
import json | |
import math | |
import numpy as np | |
from sklearn.decomposition import PCA | |
import logging | |
import time | |
from dotenv import load_dotenv | |
from huggingface_hub import hf_hub_download | |
import spaces | |
import traceback | |
from PIL import Image, ImageDraw, ImageFont | |
from io import BytesIO | |
import functools | |
import logging | |
import tempfile | |
# Dedicated logger for the definition-tree trace lines: it prints the bare
# message (no "INFO:name:" prefix) and does not propagate to the root logger,
# so those lines are not emitted a second time by the root handler.
custom_handler = logging.StreamHandler()
custom_handler.setFormatter(logging.Formatter('%(message)s'))
custom_logger = logging.getLogger("custom_logger")
custom_logger.setLevel(logging.INFO)
custom_logger.propagate = False
custom_logger.addHandler(custom_handler)

# Pull variables from a local .env file into the process environment.
load_dotenv()

# General-purpose module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Only report whether each token is present, never its value.
logger.info("HF_TOKEN_GEMMA set: %s", 'HF_TOKEN_GEMMA' in os.environ)
logger.info("HF_TOKEN_EMBEDDINGS set: %s", 'HF_TOKEN_EMBEDDINGS' in os.environ)
class Config:
    """Run-time settings shared by the loaders, the tree builder and the UI."""

    def __init__(self):
        # Hugging Face model id and the access tokens read from the environment
        # (either token is None when its variable is not set).
        self.MODEL_NAME = "google/gemma-2b"
        self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
        self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")

        # Tensor placement and precision.
        if torch.cuda.is_available():
            self.DEVICE = "cuda"
        else:
            self.DEVICE = "cpu"
        self.DTYPE = torch.float32

        # Definition-tree parameters.
        self.TOPK = 5
        self.CUTOFF = 0.0001  # Cumulative probability cutoff for tree branches
        self.OUTPUT_LENGTH = 20
        self.SUB_TOKEN_ID = 23070  # Arbitrary token ID to overwrite with embedding (token = "OSS")
        self.LOG_BASE = 10

    def get_sub_token_string(self, tokenizer):
        """Return the text that ``tokenizer`` decodes ``SUB_TOKEN_ID`` to."""
        sub_token_ids = [self.SUB_TOKEN_ID]
        return tokenizer.decode(sub_token_ids)


config = Config()
def load_tokenizer():
    """Load the tokenizer for ``config.MODEL_NAME``.

    Returns:
        The tokenizer, or ``None`` when loading fails (the error is logged).
    """
    # Log only whether a token is configured. Slicing the token raised
    # TypeError when the variable was unset (so loading was never attempted)
    # and wrote part of the secret to the logs when it was set.
    logger.info(f"Attempting to load tokenizer (access token set: {config.ACCESS_TOKEN is not None})")
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
    except Exception as e:
        logger.error(f"Error loading tokenizer: {str(e)}")
        return None
    logger.info("Tokenizer loaded successfully")
    return tokenizer
def load_model():
    """Load the causal language model for ``config.MODEL_NAME``.

    Returns:
        The model, or ``None`` when loading fails (the error is logged).
    """
    # Log only whether a token is configured. Slicing the token raised
    # TypeError when the variable was unset (so loading was never attempted)
    # and wrote part of the secret to the logs when it was set.
    logger.info(f"Attempting to load model (access token set: {config.ACCESS_TOKEN is not None})")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.MODEL_NAME,
            device_map="auto",
            token=config.ACCESS_TOKEN,
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None
    logger.info("Model loaded successfully")
    return model
def load_token_embeddings():
    """Download and load the token-embedding tensor.

    Returns:
        The embeddings converted to ``config.DTYPE`` and mapped to
        ``config.DEVICE``, or ``None`` when anything fails (the error is logged).
    """
    # Log only whether a token is configured. Slicing the token raised
    # TypeError when the variable was unset (so the download was never
    # attempted) and wrote part of the secret to the logs when it was set.
    logger.info(f"Attempting to load token embeddings (embeddings token set: {config.EMBEDDINGS_TOKEN is not None})")
    try:
        embeddings_path = hf_hub_download(
            repo_id="mwatkins1970/gemma-2b-embeddings",
            filename="gemma_2b_embeddings.pt",
            token=config.EMBEDDINGS_TOKEN
        )
        logger.info(f"Embeddings downloaded to: {embeddings_path}")
        embeddings = torch.load(embeddings_path, map_location=config.DEVICE, weights_only=True)
        logger.info("Embeddings loaded successfully")
        return embeddings.to(dtype=config.DTYPE)
    except Exception as e:
        logger.error(f"Error loading token embeddings: {str(e)}")
        return None
def load_sae_weights(sae_name):
    """Fetch (if needed) and load the encoder/decoder weights of one SAE.

    The safetensors file is cached in the working directory under a name
    derived from ``sae_name`` and is downloaded only when that file is absent.

    Args:
        sae_name: One of the keys of the ``sae_urls`` table below.

    Returns:
        ``(w_enc, w_dec)`` on ``config.DEVICE`` with dtype ``config.DTYPE``,
        or ``(None, None)`` when the download or the load fails.

    Raises:
        ValueError: If ``sae_name`` is not a known SAE.
    """
    start_time = time.time()
    base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
    sae_urls = {
        "Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
        "Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
        "Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
        "Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors"
    }
    if sae_name not in sae_urls:
        raise ValueError(f"Unknown SAE: {sae_name}")

    url = f'{base_url}{sae_urls[sae_name]}?download=true'
    local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'

    if not os.path.exists(local_filename):
        # Download to a side file and rename only once it is complete, so an
        # interrupted transfer or failed write can never be mistaken for a
        # valid cached file on the next call. A leftover ".part" file is
        # simply overwritten by the next attempt.
        partial_filename = f'{local_filename}.part'
        try:
            # (connect, read) timeouts: without them a stalled connection
            # would block the request handler forever.
            with requests.get(url, stream=True, timeout=(10, 300)) as response:
                response.raise_for_status()
                with open(partial_filename, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_filename, local_filename)
            logger.info(f'SAE weights for {sae_name} downloaded successfully!')
        except (requests.RequestException, OSError) as e:
            logger.error(f"Failed to download SAE weights for {sae_name}: {str(e)}")
            return None, None

    try:
        with safe_open(local_filename, framework="pt") as f:
            w_dec = f.get_tensor("W_dec").to(device=config.DEVICE, dtype=config.DTYPE)
            w_enc = f.get_tensor("W_enc").to(device=config.DEVICE, dtype=config.DTYPE)
        logger.info(f"Successfully loaded weights for {sae_name}")
        logger.info(f"Time taken to load weights: {time.time() - start_time:.2f} seconds")
        return w_enc, w_dec
    except Exception as e:
        logger.error(f"Error loading SAE weights for {sae_name}: {str(e)}")
        return None, None
def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor):
    """Pick one SAE feature direction and optionally re-anchor it at the centroid.

    ``weight_type == "encoder"`` selects column ``feature_number`` of ``w_enc``;
    any other value selects row ``feature_number`` of ``w_dec``. When
    ``use_token_centroid`` is true the result is the point ``scaling_factor``
    away from ``token_centroid`` in the direction of the selected vector.
    """
    if weight_type == "encoder":
        selected = w_enc[:, feature_number]
    else:
        selected = w_dec[feature_number]

    if not use_token_centroid:
        return selected

    offset = selected - token_centroid
    return token_centroid + scaling_factor * offset / torch.norm(offset)
def perform_pca(_embeddings):
    """Return the unit-length first principal component of ``_embeddings``.

    The PCA is fitted on a CPU numpy copy; the resulting direction is returned
    as a tensor with ``config.DTYPE`` on ``config.DEVICE``.

    Raises:
        RuntimeError: If any step fails; the original exception is chained.
    """
    try:
        logger.info(f"Starting PCA. Embeddings shape: {_embeddings.shape}")
        pca = PCA(n_components=1)
        embeddings_cpu = _embeddings.detach().cpu().numpy()
        logger.info(f"Embeddings converted to numpy. Shape: {embeddings_cpu.shape}")
        pca.fit(embeddings_cpu)
        logger.info("PCA fit completed")
        pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
        logger.info(f"PCA direction calculated. Shape: {pca_direction.shape}")
        normalized_direction = F.normalize(pca_direction, p=2, dim=0)
        logger.info(f"PCA direction normalized. Shape: {normalized_direction.shape}")
        return normalized_direction
    except Exception as e:
        logger.error(f"Error in perform_pca: {str(e)}")
        # The diagnostics below can fail themselves (e.g. on an empty or
        # non-tensor input); never let that hide the real error.
        try:
            logger.error(f"Embeddings stats - min: {_embeddings.min()}, max: {_embeddings.max()}, mean: {_embeddings.mean()}, std: {_embeddings.std()}")
        except Exception as stats_error:
            logger.error(f"Could not compute embeddings stats: {stats_error}")
        logger.error(traceback.format_exc())
        raise RuntimeError(f"PCA calculation failed: {str(e)}") from e
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
    """Blend the feature direction with the PCA direction and step off the centroid.

    The unit vector from the centroid to ``_feature_vector`` is mixed with
    ``_pca_direction`` using weights ``1 - pca_weight`` and ``pca_weight``; the
    mix is re-normalised and the point ``target_distance`` along it from the
    centroid is returned.
    """
    unit_feature = F.normalize(_feature_vector - _token_centroid, p=2, dim=0)
    blended = (1 - pca_weight) * unit_feature + pca_weight * _pca_direction
    unit_blend = F.normalize(blended, p=2, dim=0)
    return _token_centroid + target_distance * unit_blend
def find_closest_tokens(_emb, _token_embeddings, _tokenizer, top_k=500, num_exp=1.4, denom_exp=1.0):
    """Rank tokens by (cos distance to ``_emb``)^m / (cos distance to centroid)^n.

    Args:
        _emb: Query embedding; flattened to a single row.
        _token_embeddings: 2-D tensor, one row per token id.
        _tokenizer: Object with ``decode(list_of_ids)``.
        top_k: Maximum number of tokens to return; capped at the number of rows
            in ``_token_embeddings``.
        num_exp: Exponent ``m`` applied to the distance to ``_emb``.
        denom_exp: Exponent ``n`` applied to the distance to the centroid.

    Returns:
        List of ``(decoded_token, ratio)`` pairs, smallest ratio first.
    """
    token_centroid = torch.mean(_token_embeddings, dim=0)
    emb_norm = F.normalize(_emb.reshape(1, -1), p=2, dim=1)
    centroid_norm = F.normalize(token_centroid.reshape(1, -1), p=2, dim=1)
    normalized_embeddings = F.normalize(_token_embeddings, p=2, dim=1)

    # squeeze(0) drops only the query axis, so a one-token vocabulary still
    # yields a 1-D tensor.
    similarities_emb = torch.mm(emb_norm, normalized_embeddings.t()).squeeze(0)
    similarities_centroid = torch.mm(centroid_norm, normalized_embeddings.t()).squeeze(0)

    # Rounding can push a cosine similarity slightly above 1; a negative base
    # with a fractional exponent would then produce NaN and corrupt the
    # ranking, so clamp the distances at zero first.
    distances_emb = torch.pow((1 - similarities_emb).clamp(min=0), num_exp)
    distances_centroid = torch.pow((1 - similarities_centroid).clamp(min=0), denom_exp)
    ratios = distances_emb / distances_centroid

    # topk raises when k exceeds the number of candidates.
    k = min(top_k, ratios.numel())
    top_ratios, top_indices = torch.topk(ratios, k=k, largest=False)
    closest_tokens = [_tokenizer.decode([token_id]) for token_id in top_indices.tolist()]
    return list(zip(closest_tokens, top_ratios.tolist()))
def get_neuronpedia_url(layer, feature):
    """Build the Neuronpedia embed URL for one feature of a gemma-2b layer."""
    embed_options = (
        "embed=true",
        "embedexplanation=true",
        "embedplots=false",
        "embedtest=false",
        "height=300",
    )
    page = f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}"
    return page + "?" + "&".join(embed_options)
# New functions for tree generation and visualization | |
def update_token_embedding(model, token_id, new_embedding):
    """Overwrite row ``token_id`` of the model's input-embedding matrix in place.

    ``new_embedding`` is moved to the embedding matrix's device first. The
    previous row is not saved anywhere by this function.
    """
    embedding_weight = model.get_input_embeddings().weight
    replacement = new_embedding.to(embedding_weight.device)
    embedding_weight.data[token_id] = replacement
def produce_next_token_ids(input_ids, model, topk, sub_token_id):
    """Return the ``topk`` most probable next tokens for the first sequence.

    The logit of ``sub_token_id`` is forced to ``-inf`` before the softmax, so
    that token gets probability zero.

    Returns:
        ``(token_ids, probabilities)`` for batch element 0, ordered by
        decreasing probability.
    """
    with torch.no_grad():
        next_logits = model(input_ids.to(model.device)).logits[:, -1, :]
        next_logits[:, sub_token_id] = float('-inf')
        next_probs = torch.softmax(next_logits, dim=-1)
        best_probs, best_ids = torch.topk(next_probs, k=topk, dim=-1)
    return best_ids[0], best_probs[0]
def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
    """Recursively grow the definition tree under ``data``, depth first.

    For the current prefix the ``config.TOPK`` next tokens are requested; every
    token whose cumulative probability stays at or above ``config.CUTOFF`` is
    appended to ``data['children']`` and expanded in turn.

    Args:
        input_ids: Token ids of the current prefix; row 0 is decoded for the trace.
        data: Tree node (dict with a ``'children'`` list) that is mutated in place.
        base_prompt: Prompt text stripped from the front of each trace line.
        model, tokenizer: Passed through to the next-token helper / used to decode.
        config: Provides ``CUTOFF``, ``TOPK`` and ``SUB_TOKEN_ID``.
        depth: Current recursion depth.
        max_depth: Depth at which expansion stops.
        cumulative_prob: Probability of the path leading to this node.

    Yields:
        One formatted trace line (``str``) per expanded node. The same line is
        also written to ``custom_logger``.
    """
    if depth >= max_depth or cumulative_prob < config.CUTOFF:
        return

    current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    # Keep only the part that extends the base prompt, on a single line.
    extended_prompt = current_prompt[len(base_prompt):].strip().replace("\n", "|")
    # Pad so that "PROB:" lines up vertically.
    formatted_line = f"Depth {depth}: {extended_prompt:<45} PROB: {cumulative_prob:.4f}"
    custom_logger.info(formatted_line)
    # Callers collect the str items of this generator as the visible log;
    # without this yield that log was always empty.
    yield formatted_line

    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
    for token_id, token_prob in zip(top_k_ids.tolist(), top_k_probs.tolist()):
        if token_id == config.SUB_TOKEN_ID:
            continue  # Skip the substitute token to avoid circular definitions
        new_cumulative_prob = cumulative_prob * token_prob
        if new_cumulative_prob < config.CUTOFF:
            continue
        # Build the extended prefix only for branches that are actually kept.
        token_id_tensor = torch.tensor([[token_id]], dtype=torch.long).to(model.device)
        new_input_ids = torch.cat([input_ids, token_id_tensor], dim=-1)
        new_child = {
            "token_id": token_id,
            "token": tokenizer.decode([token_id], skip_special_tokens=True),
            "cumulative_prob": new_cumulative_prob,
            "children": []
        }
        data['children'].append(new_child)
        yield from build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
    """Install ``embedding`` as the substitute token and stream the definition tree.

    The embedding row for ``config.SUB_TOKEN_ID`` is overwritten with
    ``embedding``, ``base_prompt`` is encoded, and the tree is built from it.

    Yields:
        Every item produced by ``build_def_tree`` and, last, the root dict of
        the finished tree.
    """
    logger.info(f"Starting generate_definition_tree with base_prompt: {base_prompt}")
    results_dict = {"token": "", "cumulative_prob": 1, "children": []}

    # Overwrite the substitute token's embedding row.
    token_embedding = embedding.unsqueeze(0).to(model.device)
    update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)

    # Clear the model's cache if it has one.
    if hasattr(model, 'reset_cache'):
        model.reset_cache()

    input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
    logger.info(f"Encoded input_ids: {input_ids}")

    yield from build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)

    logger.info("Finished building tree, yielding results_dict")
    yield results_dict
def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
    """Return ``(max, min)`` of ``'cumulative_prob'`` over ``node`` and its descendants.

    A node without the key counts as 0 for the maximum and as 1 for the
    minimum; values that are not strictly positive never lower the minimum.
    ``current_max`` / ``current_min`` seed the running extremes.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        current_max = max(current_max, current.get('cumulative_prob', 0))
        candidate_min = current.get('cumulative_prob', 1)
        if candidate_min > 0:
            current_min = min(current_min, candidate_min)
        # Reversed so that popping visits children in their listed order
        # (pre-order, like the recursive formulation).
        pending.extend(reversed(current.get('children', [])))
    return current_max, current_min
def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
    """Map a cumulative weight to an edge thickness on a log scale.

    The weight is clamped up to ``min_weight``, positioned logarithmically
    between ``min_weight`` and ``max_weight``, raised to the power 2.5 and
    scaled into ``[min_thickness, max_thickness]``.

    When the weight range is degenerate (``max_weight`` not greater than
    ``min_weight``) there is nothing to interpolate and ``max_thickness`` is
    returned instead of dividing by zero.
    """
    cumulative_weight = max(cumulative_weight, min_weight)
    log_min = math.log(min_weight, log_base)
    log_max = math.log(max_weight, log_base) - log_min
    # "not > 0" also catches NaN (e.g. an infinite min_weight).
    if not log_max > 0:
        return float(max_thickness)
    log_weight = math.log(cumulative_weight, log_base) - log_min
    amplified_weight = (log_weight / log_max) ** 2.5
    return (amplified_weight * (max_thickness - min_thickness)) + min_thickness
def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_root=True, depth=0, trim_cutoff=0):
    """Add ``node`` and its subtree to the graph object ``dot``.

    Non-root nodes whose cumulative probability is below ``trim_cutoff`` are
    dropped together with their subtree. A node is drawn when it is the root
    (label ``*``) or has a non-blank token; non-root nodes also get an edge
    from ``parent`` whose pen width comes from ``scale_edge_width``.

    NOTE(review): the source's indentation was lost; this assumes the child
    loop runs for every surviving node, not only for drawn ones - confirm.
    """
    node_id = str(id(node))
    token = node.get('token', '').strip()
    cumulative_prob = node.get('cumulative_prob', 1)

    if not is_root and cumulative_prob < trim_cutoff:
        return

    if is_root or token:
        if parent and not is_root:
            pen_width = str(scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE))
            dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=pen_width)
        dot.node(
            node_id,
            label="*" if is_root else token,
            shape='plaintext',
            fontsize="36",
            fontname='Helvetica',
        )

    for child in node.get('children', []):
        add_nodes_edges(
            dot, child, config, max_weight, min_weight,
            parent=node_id, is_root=False, depth=depth + 1, trim_cutoff=trim_cutoff,
        )
import networkx as nx | |
import matplotlib.pyplot as plt | |
def add_nodes_edges_nx(G, node, parent=None, is_root=True):
    """Add ``node`` and its subtree to the graph ``G``.

    The root is labelled ``*``; other nodes are labelled with their stripped
    token and are added only when that token is non-blank. Added nodes with a
    parent get an edge from it.

    NOTE(review): the source's indentation was lost; this assumes the edge is
    created only for added nodes and that children are always visited - confirm.
    """
    node_id = str(id(node))
    token = node.get('token', '').strip()

    if is_root or token:
        label = "*" if is_root else token
        G.add_node(node_id, label=label)
        if parent:
            G.add_edge(parent, node_id)

    for child in node.get('children', []):
        add_nodes_edges_nx(G, child, parent=node_id, is_root=False)
def _prune_tree(node, trim_cutoff, is_root=True):
    """Return a copy of ``node`` without non-root subtrees below ``trim_cutoff``."""
    if not is_root and node.get('cumulative_prob', 1) < trim_cutoff:
        return None
    pruned = dict(node)
    kept_children = []
    for child in node.get('children', []):
        kept = _prune_tree(child, trim_cutoff, is_root=False)
        if kept is not None:
            kept_children.append(kept)
    pruned['children'] = kept_children
    return pruned


def create_tree_diagram(data, config, max_weight, min_weight, trim_cutoff=0):
    """Render the definition tree to a PNG file and return the file's path.

    Args:
        data: Root node of the tree (dict with ``'token'`` / ``'children'``).
        config, max_weight, min_weight: Accepted for interface compatibility;
            not used by this renderer.
        trim_cutoff: Non-root nodes with ``'cumulative_prob'`` below this value
            are left out together with their subtrees. The default of 0 keeps
            the whole tree. (Previously this argument was silently ignored.)

    Returns:
        Path of a temporary ``.png`` file; the caller owns the file.
    """
    G = nx.DiGraph()
    add_nodes_edges_nx(G, _prune_tree(data, trim_cutoff))

    fig = plt.figure(figsize=(12, 12))
    try:
        pos = nx.spring_layout(G, k=0.5, iterations=50)
        labels = nx.get_node_attributes(G, 'label')
        nx.draw(G, pos, labels=labels, with_labels=True, node_size=5000, node_color="lightblue", font_size=10, font_weight="bold", edge_color="gray", arrows=False)
        # Reserve a unique path and close the handle straight away, so no file
        # descriptor is leaked and the file can be reopened for writing.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            image_path = temp_file.name
        fig.savefig(image_path)
    finally:
        # Release the figure even when layout or drawing fails.
        plt.close(fig)
    return image_path
# Module-level holders for the loaded resources. The first three start empty
# and are assigned by initialize_resources(); the two dicts map an SAE name to
# its encoder / decoder weights and are filled on first use.
tokenizer = model = token_embeddings = None
w_enc_dict, w_dec_dict = {}, {}
def cached_load_tokenizer():
    """Return ``load_tokenizer()``'s result (no caching is done at this level)."""
    loaded_tokenizer = load_tokenizer()
    return loaded_tokenizer
def cached_load_model():
    """Return ``load_model()``'s result (no caching is done at this level)."""
    loaded_model = load_model()
    return loaded_model
def cached_load_token_embeddings():
    """Return ``load_token_embeddings()``'s result (no caching is done at this level)."""
    loaded_embeddings = load_token_embeddings()
    return loaded_embeddings
def initialize_resources():
    """Load tokenizer, model and token embeddings into the module globals.

    Loading stops at the first resource that comes back as ``None``.

    Raises:
        RuntimeError: Naming the resource that failed to load.
    """
    global tokenizer, model, token_embeddings

    def _require(resource, description):
        # Loaders signal failure by returning None.
        if resource is None:
            raise RuntimeError(f"Failed to load {description}.")

    logger.info("Initializing resources...")

    tokenizer = cached_load_tokenizer()
    _require(tokenizer, "tokenizer")

    model = cached_load_model()
    _require(model, "model")

    token_embeddings = cached_load_token_embeddings()
    _require(token_embeddings, "token embeddings")

    logger.info("Resources initialized successfully.")
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=None):
    """Build the feature vector for one SAE feature and run the selected mode.

    Args:
        selected_sae: SAE name understood by ``load_sae_weights``.
        feature_number: Feature index (converted with ``int``).
        weight_type, use_token_centroid, scaling_factor: Passed to
            ``create_feature_vector``.
        use_pca, pca_weight: When ``use_pca`` is true the vector is replaced by
            ``create_ghost_token`` using the first PCA direction of the token
            embeddings.
        num_exp, denom_exp: Exponents passed to ``find_closest_tokens``.
        mode: ``"cosine distance token lists"`` or
            ``"definition tree generation"``.
        top_500: In token-list mode, return all 500 tokens with their ratios
            instead of the first 100 tokens.
        progress: Unused; accepted for interface compatibility.

    Returns:
        ``(text, image_path_or_None)``. Failures are reported in ``text``
        (prefixed with ``"Error"`` or ``"Failed"``) rather than raised.
    """
    try:
        logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")

        # Load the SAE weights if they are not already loaded.
        if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
            logger.info("Loading SAE weights for {}".format(selected_sae))
            w_enc, w_dec = load_sae_weights(selected_sae)
            if w_enc is None or w_dec is None:
                error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
                logger.error(error_message)
                return error_message, None
            w_enc_dict[selected_sae] = w_enc
            w_dec_dict[selected_sae] = w_dec
        else:
            w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]

        # Create the feature vector.
        token_centroid = torch.mean(token_embeddings, dim=0)
        feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
        logger.info(f"Feature vector created. Shape: {feature_vector.shape}")

        # Apply PCA if requested.
        if use_pca:
            pca_direction = perform_pca(token_embeddings)
            feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
            logger.info(f"PCA applied. New feature vector shape: {feature_vector.shape}")

        if mode == "cosine distance token lists":
            logger.info("Generating cosine distance token list")
            closest_tokens_with_values = find_closest_tokens(
                feature_vector, token_embeddings, tokenizer,
                top_k=500, num_exp=num_exp, denom_exp=denom_exp
            )
            if top_500:
                result = ", ".join([f"'{token}': {value:.4f}" for token, value in closest_tokens_with_values])
                logger.info("Returning top 500 list")
                return result, None

            token_list = [token for token, _ in closest_tokens_with_values[:100]]
            result = "100 tokens whose embeddings produce the smallest ratio (cos distance to feature vector)^m/(cos distance to token centroid)^n:\n\n"
            result += f"[{', '.join(repr(token) for token in token_list)}]\n"
            logger.info("Returning top 100 tokens")
            return result, None

        if mode == "definition tree generation":
            logger.info("Generating definition tree")
            base_prompt = f'A typical definition of "{config.get_sub_token_string(tokenizer)}" would be "'
            tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)

            # The generator yields str trace lines and, last, the tree dict.
            log_output = []
            tree_data = None
            for item in tree_generator:
                if isinstance(item, str):
                    log_output.append(item)
                else:
                    tree_data = item
            log_text = "\n".join(log_output)

            if not tree_data:
                logger.error("Failed to generate tree data")
                return "Error: Failed to generate tree data.", None

            logger.info("Generating tree image")
            max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
            tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight)
            logger.info("Tree image generated successfully")
            return log_text, tree_image

        return "Mode not recognized or not implemented in this step.", None
    except Exception as e:
        logger.error(f"Error in process_input: {str(e)}")
        return f"Error: {str(e)}", None
    finally:
        # Locals are released when the function returns, so no explicit `del`
        # is needed. The former unguarded `del feature_vector` raised
        # UnboundLocalError on every exit taken before the vector existed,
        # replacing the intended error message with a crash.
        torch.cuda.empty_cache()
def trim_tree(trim_cutoff, tree_data):
    """Re-render the tree diagram with ``trim_cutoff`` applied.

    Returns:
        Whatever ``create_tree_diagram`` returns, or ``None`` when there is no
        tree yet (``tree_data is None``).
    """
    # Nothing to trim before a tree has been generated.
    if tree_data is None:
        return None
    max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
    return create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
def gradio_interface():
    """Build and return the Gradio Blocks app for the SAE feature explorer.

    NOTE(review): the source's indentation was lost, so the placement of the
    components inside the two columns is an assumption - confirm the layout.
    """

    def update_visibility(mode):
        """Show the tree widgets in tree mode, the top-500 widgets otherwise."""
        tree_mode = mode == "definition tree generation"
        # Order matches the outputs of mode.change below:
        # output_image, trim_slider, trim_btn, generate_top_500_btn, output_500_text
        return (
            gr.update(visible=tree_mode),
            gr.update(visible=tree_mode),
            gr.update(visible=tree_mode),
            gr.update(visible=not tree_mode),
            gr.update(visible=not tree_mode),
        )

    def update_neuronpedia(selected_sae, feature_number):
        """Return the Neuronpedia iframe for the current SAE / feature."""
        # The dropdown starts without a selection and the number box can be
        # cleared; there is nothing to embed in either case.
        if not selected_sae or feature_number is None:
            return ""
        layer_number = int(selected_sae.split()[-1])
        # int() keeps a float such as 12.0 out of the URL path.
        url = get_neuronpedia_url(layer_number, int(feature_number))
        return f'<iframe src="{url}" width="100%" height="300px"></iframe>'

    def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
        """Run the selected mode; returns ``(text, image_path_or_None)``."""
        # Same pipeline as process_input, so delegate instead of keeping a copy.
        return process_input(
            selected_sae, feature_number, weight_type, use_token_centroid,
            scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode,
            top_500=False, progress=progress,
        )

    def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
        """Return the full top-500 token/ratio list as text."""
        # process_input returns (text, image); this handler feeds a single
        # Textbox, so only the text may be returned.
        text, _ = process_input(
            selected_sae, feature_number, weight_type, use_token_centroid,
            scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode,
            top_500=True,
        )
        return text

    def trim_tree(trim_cutoff, tree_data):
        """Re-render the tree with the cutoff applied; None when no tree exists."""
        if tree_data is None:
            return None
        max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
        trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
        return trimmed_tree_image

    with gr.Blocks() as demo:
        gr.Markdown("# Gemma-2B SAE Feature Explorer")
        with gr.Row():
            with gr.Column(scale=2):
                selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
                feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
                mode = gr.Radio(
                    choices=["cosine distance token lists", "definition tree generation"],
                    label="Select mode",
                    value="cosine distance token lists"
                )
                weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
                use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
                scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
                num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
                denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
                use_pca = gr.Checkbox(label="Introduce first PCA component")
                pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")
            with gr.Column(scale=3):
                generate_btn = gr.Button("Generate Output")
                output_stream = gr.Textbox(label="Output", lines=20)
                output_image = gr.Image(label="Tree Diagram", visible=False)
                generate_top_500_btn = gr.Button("Generate Top 500 Tokens and Power Ratios", visible=True)
                output_500_text = gr.Textbox(label="Top 500 Output", lines=20, visible=False)
                trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
                trim_btn = gr.Button("Trim Tree", visible=False)
                # NOTE(review): no handler writes to this state, so trim_btn
                # always receives None and trim_tree returns None - wire the
                # generated tree into it if trimming is meant to work.
                tree_data_state = gr.State()
                neuronpedia_html = gr.HTML(label="Neuronpedia")

        inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]

        # Registered exactly once: a second identical click handler made every
        # press run the whole pipeline twice.
        generate_btn.click(
            update_output,
            inputs=inputs,
            outputs=[output_stream, output_image],
            show_progress="full"
        ).then(lambda: gr.update(visible=False, value=""), None, [output_500_text])

        generate_top_500_btn.click(
            generate_top_500,
            inputs=inputs,
            outputs=[output_500_text],
            show_progress="full"
        ).then(lambda: gr.update(visible=True), None, [output_500_text])

        trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
        mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn, generate_top_500_btn, output_500_text])
        selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
        feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])

        # The top-500 controls are only offered once a top-100 list is shown.
        output_stream.change(
            lambda text: (gr.update(visible=True), gr.update(visible=True)) if "100 tokens" in text else (gr.update(visible=False), gr.update(visible=False)),
            inputs=[output_stream],
            outputs=[generate_top_500_btn, output_500_text]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: load the resources, build the UI and serve it.
    # Any startup failure is logged with its traceback instead of propagating.
    try:
        logger.info("Starting application initialization...")
        initialize_resources()

        logger.info("Creating Gradio interface...")
        iface = gradio_interface()

        logger.info("Launching Gradio interface...")
        iface.launch()
        logger.info("Gradio interface launched successfully")
    except Exception as startup_error:
        logger.error(f"Error during application startup: {startup_error}")
        logger.error(traceback.format_exc())