Nu Appleblossom committed 2eae6d2 (parent: 0f58a0a): "updated application"

app.py CHANGED
@@ -1,14 +1,268 @@
+import gradio as gr
 import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from safetensors import safe_open
+import os
+import json
+import math
+import random
+import numpy as np
+import matplotlib.pyplot as plt
+from graphviz import Digraph
+from PIL import Image, ImageDraw, ImageFont
+from io import BytesIO
+from sklearn.decomposition import PCA
 import logging
+import time
+from dotenv import load_dotenv
+from huggingface_hub import hf_hub_download
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-
+logger.info(f"HF_TOKEN_GEMMA set: {'HF_TOKEN_GEMMA' in os.environ}")
+logger.info(f"HF_TOKEN_EMBEDDINGS set: {'HF_TOKEN_EMBEDDINGS' in os.environ}")
+
 if not torch.cuda.is_available():
     raise RuntimeError("GPU is required but not available. ZeroGPU may not be initialized properly.")
 else:
     logger.info(f"CUDA is available. Device: {torch.cuda.get_device_name(0)}")
-
-
+
+# Load environment variables
+load_dotenv()
+
+class Config:
+    def __init__(self):
+        self.MODEL_NAME = "google/gemma-2b"
+        self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
+        self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")
+        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+        self.DTYPE = torch.float32
+        self.TOPK = 5
+        self.CUTOFF = 0.00001  # Cumulative probability cutoff for tree branches
+        self.OUTPUT_LENGTH = 20
+        self.SUB_TOKEN_ID = 23070  # Arbitrary token ID to overwrite with embedding
+        self.LOG_BASE = 10
+
+config = Config()
+
+def load_tokenizer_and_model():
+    try:
+        token_preview = config.ACCESS_TOKEN[:5] if config.ACCESS_TOKEN else "<missing>"
+        logger.info(f"Loading tokenizer and model with token: {token_preview}...")
+        # `token` supersedes the deprecated `use_auth_token` argument
+        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
+        model = AutoModelForCausalLM.from_pretrained(config.MODEL_NAME, device_map="auto", token=config.ACCESS_TOKEN)
+        # device_map="auto" already places the model, so no explicit .to(config.DEVICE) is needed
+        logger.info("Model and tokenizer loaded successfully")
+        return model, tokenizer
+    except Exception as e:
+        logger.error(f"Error loading model or tokenizer: {str(e)}")
+        return None, None
+
+def get_embeddings(model):
+    return model.get_input_embeddings().weight.data.to(config.DEVICE)
+
+def update_token_embedding(model, token_id, new_embedding):
+    new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
+    model.get_input_embeddings().weight.data[token_id] = new_embedding
+
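Note that update_token_embedding mutates the model's embedding matrix in place, so the original row for SUB_TOKEN_ID is lost after the first substitution. A minimal save-and-restore sketch (standalone, with a toy nn.Embedding standing in for Gemma's input embeddings):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)      # toy stand-in for model.get_input_embeddings()
token_id = 7                   # plays the role of config.SUB_TOKEN_ID

original_row = emb.weight.data[token_id].clone()  # copy before overwriting

emb.weight.data[token_id] = torch.randn(4)        # in-place substitution, as above

# ... generate with the substituted embedding ...

emb.weight.data[token_id] = original_row          # restore for the next feature
assert torch.equal(emb.weight.data[token_id], original_row)
```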
+def produce_next_token_ids(input_ids, model, topk, sub_token_id):
+    input_ids = input_ids.to(model.device)
+    with torch.no_grad():
+        outputs = model(input_ids)
+    logits = outputs.logits
+    last_logits = logits[:, -1, :]
+    last_logits[:, sub_token_id] = float('-inf')
+    softmax_probs = torch.softmax(last_logits, dim=-1)
+    top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
+    return top_k_ids[0], top_k_probs[0]
+
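Setting the substitute token's logit to -inf before the softmax guarantees it can never appear among the top-k continuations of its own definition. A self-contained illustration with a hand-built logit vector, no model needed:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 3.5, 0.5, 3.0]])  # pretend 5-token vocabulary
sub_token_id = 2                                     # token to exclude

logits[:, sub_token_id] = float('-inf')              # mask before softmax
probs = torch.softmax(logits, dim=-1)
top_k_probs, top_k_ids = torch.topk(probs, k=3, dim=-1)

print(top_k_ids[0].tolist())    # token 2 never appears
print(top_k_probs[0].tolist())  # probability mass renormalised over the rest
```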
+def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
+    if depth >= max_depth or cumulative_prob < config.CUTOFF:
+        return
+
+    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
+
+    for idx, token_id in enumerate(top_k_ids.tolist()):
+        if token_id == config.SUB_TOKEN_ID:
+            continue  # Skip the substitute token to avoid circular definitions
+
+        token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
+        new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
+
+        new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
+        if new_cumulative_prob < config.CUTOFF:
+            continue
+
+        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+        new_child = {
+            "token_id": token_id,
+            "token": token_str,
+            "cumulative_prob": new_cumulative_prob,
+            "children": []
+        }
+        data['children'].append(new_child)
+
+        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
+
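Branches are pruned on the product of probabilities along the path, so tree depth is adaptive: with CUTOFF = 1e-5, a chain of 0.5-probability tokens survives 16 steps and a chain of 0.2-probability tokens only 7, while a 0.9-probability chain hits max_depth = 25 first. A toy enumeration of that behaviour:

```python
CUTOFF = 1e-5  # mirrors config.CUTOFF

def branch_length(step_prob, max_depth=25):
    """Tokens a branch can add before its cumulative probability drops below CUTOFF."""
    length, cumulative = 0, 1.0
    while length < max_depth and cumulative * step_prob >= CUTOFF:
        cumulative *= step_prob
        length += 1
    return length

for p in (0.9, 0.5, 0.2):
    print(f"per-token prob {p}: survives {branch_length(p)} tokens")
```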
+def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
+    results_dict = {"token": "", "cumulative_prob": 1, "children": []}
+
+    # Overwrite the substitute token's embedding with the feature vector
+    token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
+    update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)
+
+    # Clear the model's cache if it has one
+    if hasattr(model, 'reset_cache'):
+        model.reset_cache()
+
+    input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
+    build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)
+
+    return results_dict
+
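The core trick: the feature vector is written over the embedding of an arbitrary token (SUB_TOKEN_ID = 23070), and the model is then prompted to define that token, so the tree reflects what the model reads in the injected vector. A sketch of the prompt construction; "pet" is a placeholder for whatever surface string token 23070 actually decodes to:

```python
# Sketch, assuming a Gemma tokenizer; "pet" stands in for
# whatever tokenizer.decode([23070]) actually yields.
sub_token_str = "pet"
base_prompt = f'A typical definition of "{sub_token_str}" would be "'
print(base_prompt)  # -> A typical definition of "pet" would be "
```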
+def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
+    current_max = max(current_max, node.get('cumulative_prob', 0))
+    if node.get('cumulative_prob', 1) > 0:
+        current_min = min(current_min, node.get('cumulative_prob', 1))
+    for child in node.get('children', []):
+        current_max, current_min = find_max_min_cumulative_weight(child, current_max, current_min)
+    return current_max, current_min
+
+def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
+    cumulative_weight = max(cumulative_weight, min_weight)
+    log_weight = math.log(cumulative_weight, log_base) - math.log(min_weight, log_base)
+    log_max = math.log(max_weight, log_base) - math.log(min_weight, log_base)
+    if log_max == 0:  # all weights equal; avoid division by zero
+        return max_thickness
+    amplified_weight = (log_weight / log_max) ** 2.5
+    scaled_weight = (amplified_weight * (max_thickness - min_thickness)) + min_thickness
+    return scaled_weight
+
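scale_edge_width normalises log-probabilities to [0, 1] and raises them to the 2.5 power, so pen width collapses quickly for weak branches. A quick numeric check mirroring the app's defaults (log base 10, widths 1 to 33), with illustrative weights:

```python
import math

# Mirrors scale_edge_width above, with the app's default parameters
def width(w, max_w, min_w, log_base=10, max_t=33, min_t=1):
    w = max(w, min_w)
    log_w = math.log(w, log_base) - math.log(min_w, log_base)
    log_max = math.log(max_w, log_base) - math.log(min_w, log_base)
    return ((log_w / log_max) ** 2.5) * (max_t - min_t) + min_t

for w in (1e-5, 1e-3, 1e-1, 0.5):  # illustrative cumulative probabilities
    print(f"{w:g}: pen width {width(w, max_w=0.5, min_w=1e-5):.1f}")
```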
+def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_root=True, depth=0, trim_cutoff=0):
+    node_id = str(id(node))
+    token = node.get('token', '').strip()
+    cumulative_prob = node.get('cumulative_prob', 1)
+
+    if cumulative_prob < trim_cutoff and not is_root:
+        return
+
+    if is_root or token:
+        if parent and not is_root:
+            edge_weight = scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE)
+            dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=str(edge_weight))
+
+        label = "*" if is_root else token
+        dot.node(node_id, label=label, shape='plaintext', fontsize="36", fontname='Helvetica')
+
+    for child in node.get('children', []):
+        add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, trim_cutoff=trim_cutoff)
+
+def create_tree_diagram(data, config, trim_cutoff=0):
+    dot = Digraph(comment='Definition Tree', format='png')
+    dot.attr(rankdir='LR', size='5040,5000', margin='0.06', nodesep='0.06', ranksep='1', dpi='120', bgcolor='white')
+
+    max_weight, min_weight = find_max_min_cumulative_weight(data)
+    add_nodes_edges(dot, data, config, max_weight, min_weight, trim_cutoff=trim_cutoff)
+
+    # Render in memory: dot.render() writes to disk, so pipe the PNG bytes instead
+    output = BytesIO(dot.pipe(format='png'))
+
+    # Add white background and vertically centre the tree
+    with Image.open(output) as img:
+        bg = Image.new("RGB", (img.width, 5000), (255, 255, 255))
+        y_offset = (5000 - img.height) // 2
+        bg.paste(img, (0, y_offset))
+        final_output = BytesIO()
+        bg.save(final_output, 'PNG')
+        final_output.seek(0)
+
+    return final_output
+
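Digraph.render() writes files to disk, which is why the code above pipes PNG bytes into a BytesIO instead. A minimal standalone graph-to-PIL round trip (assumes the Graphviz binaries are installed):

```python
from io import BytesIO

from graphviz import Digraph
from PIL import Image

dot = Digraph(format='png')
dot.edge('feature', 'definition')

png_bytes = dot.pipe(format='png')   # render in memory instead of dot.render()
img = Image.open(BytesIO(png_bytes))
print(img.size)
```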
+def get_neuronpedia_url(layer, feature):
+    return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
+
+@torch.no_grad()
+def generate_definition_tree_placeholder(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight):
+    return "Definition tree generation placeholder"
+
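The URL pins the res-jb SAE release on Neuronpedia; for example, layer 6, feature 123 resolves as below (layer and feature values are illustrative):

```python
layer, feature = 6, 123  # illustrative values
url = f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
print(url)
```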
+def gradio_interface():
+    model, tokenizer = load_tokenizer_and_model()
+    if model is None or tokenizer is None:
+        return gr.Interface(lambda: "Failed to load model and tokenizer. Please check the logs for more details.", inputs=[], outputs="text")
+
+    embeddings = get_embeddings(model)
+
+    with gr.Blocks() as demo:
+        gr.Markdown("# Gemma-2B SAE Feature Explorer")
+
+        with gr.Row():
+            with gr.Column():
+                selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
+                feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
+
+                mode = gr.Radio(
+                    choices=["cosine distance token lists", "definition tree generation"],
+                    label="Select mode",
+                    value="cosine distance token lists"
+                )
+
+                weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
+                use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
+                scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
+
+                num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
+                denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
+                use_pca = gr.Checkbox(label="Introduce first PCA component")
+                pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")
+
+            with gr.Column():
+                output = gr.Image(label="Tree Diagram Output")
+                neuronpedia_embed = gr.HTML(label="Neuronpedia Embed")
+                trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")
+
+        def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, trim_cutoff):
+            neuronpedia_url = get_neuronpedia_url(selected_sae.split(" ")[-1], feature_number)
+            iframe_html = f'<iframe src="{neuronpedia_url}" width="100%" height="300" frameborder="0"></iframe>'
+
+            if mode == "cosine distance token lists":
+                # Keep the original functionality here
+                return iframe_html, None
+            elif mode == "definition tree generation":
+                embedding = embeddings[int(feature_number)].to(config.DEVICE)
+                if use_token_centroid:
+                    token_centroid = torch.mean(embeddings, dim=0).to(config.DEVICE)
+                    embedding = token_centroid + scaling_factor * (embedding - token_centroid) / torch.norm(embedding - token_centroid)
+
+                base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
+                results_dict = generate_definition_tree(base_prompt, embedding, model, tokenizer, config)
+                tree_diagram = create_tree_diagram(results_dict, config, trim_cutoff=trim_cutoff)
+                # Handlers must return new values; calling .update() on a component has no effect
+                return iframe_html, Image.open(tree_diagram)
+
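Gradio event handlers change components through their return values, one per entry in outputs; calling .update() on a component inside a handler does nothing, which is why update_output returns the iframe HTML rather than calling neuronpedia_embed.update(). A minimal sketch of the pattern with hypothetical components:

```python
import gradio as gr

with gr.Blocks() as demo:
    mode = gr.Radio(["a", "b"], value="a", label="mode")       # hypothetical components
    slider = gr.Slider(0, 1, label="only relevant in mode a")

    def on_mode_change(selected):
        # one return value per entry in `outputs`
        return gr.update(visible=(selected == "a"))

    mode.change(on_mode_change, inputs=[mode], outputs=[slider])
```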
+        def update_ui(mode_selected):
+            show_cosine_controls = mode_selected == "cosine distance token lists"
+            return (
+                gr.update(visible=show_cosine_controls),
+                gr.update(visible=show_cosine_controls),
+                gr.update(visible=show_cosine_controls),
+                gr.update(visible=show_cosine_controls)
+            )
+
+        inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, trim_slider]
+        mode.change(
+            update_ui, inputs=[mode],
+            outputs=[num_exp, denom_exp, use_pca, pca_weight]  # four outputs to match the four updates returned
+        )
+
+        gr.Button("Generate Output").click(update_output, inputs=inputs, outputs=[neuronpedia_embed, output])
+
+    return demo
+
+
+if __name__ == "__main__":
+    iface = gradio_interface()
+    iface.launch()