Nu Appleblossom committed on
Commit 2eae6d2 · 1 Parent(s): 0f58a0a

updated application

Files changed (1)
  1. app.py +257 -3
app.py CHANGED
@@ -1,14 +1,268 @@
+ import gradio as gr
  import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from safetensors import safe_open
+ import os
+ import json
+ import math
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from graphviz import Digraph
+ from PIL import Image, ImageDraw, ImageFont
+ from io import BytesIO
+ from sklearn.decomposition import PCA
  import logging
+ import time
+ from dotenv import load_dotenv
+ from huggingface_hub import hf_hub_download

  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # Check if CUDA is available
+ logger.info(f"HF_TOKEN_GEMMA set: {'HF_TOKEN_GEMMA' in os.environ}")
+ logger.info(f"HF_TOKEN_EMBEDDINGS set: {'HF_TOKEN_EMBEDDINGS' in os.environ}")
+
  if not torch.cuda.is_available():
      raise RuntimeError("GPU is required but not available. ZeroGPU may not be initialized properly.")
  else:
      logger.info(f"CUDA is available. Device: {torch.cuda.get_device_name(0)}")
-     logger.info(f"Current CUDA device: {torch.cuda.current_device()}")
-     logger.info(f"Memory allocated: {torch.cuda.memory_allocated()} bytes")
+
+ # Load environment variables
+ load_dotenv()
+
+ class Config:
+     def __init__(self):
+         self.MODEL_NAME = "google/gemma-2b"
+         self.ACCESS_TOKEN = os.environ.get("HF_TOKEN_GEMMA")
+         self.EMBEDDINGS_TOKEN = os.environ.get("HF_TOKEN_EMBEDDINGS")
+         self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+         self.DTYPE = torch.float32
+         self.TOPK = 5
+         self.CUTOFF = 0.00001  # Cumulative probability cutoff for tree branches
+         self.OUTPUT_LENGTH = 20
+         self.SUB_TOKEN_ID = 23070  # Arbitrary token ID to overwrite with embedding
+         self.LOG_BASE = 10
+
+ config = Config()
+
+ def load_tokenizer_and_model():
+     try:
+         logger.info(f"Loading tokenizer and model with token: {config.ACCESS_TOKEN[:5]}...")
+         tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, use_auth_token=config.ACCESS_TOKEN)
+         model = AutoModelForCausalLM.from_pretrained(config.MODEL_NAME, device_map="auto", use_auth_token=config.ACCESS_TOKEN)
+         model.to(config.DEVICE)  # Ensure the model is on the correct device
+         logger.info("Model and tokenizer loaded successfully")
+         return model, tokenizer
+     except Exception as e:
+         logger.error(f"Error loading model or tokenizer: {str(e)}")
+         return None, None
+
+ def get_embeddings(model):
+     return model.get_input_embeddings().weight.data.to(config.DEVICE)
+
+
+ def update_token_embedding(model, token_id, new_embedding):
+     new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
+     model.get_input_embeddings().weight.data[token_id] = new_embedding
+
+ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
+     input_ids = input_ids.to(model.device)
+     with torch.no_grad():
+         outputs = model(input_ids)
+         logits = outputs.logits
+         last_logits = logits[:, -1, :]
+         last_logits[:, sub_token_id] = float('-inf')
+         softmax_probs = torch.softmax(last_logits, dim=-1)
+         top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
+     return top_k_ids[0], top_k_probs[0]
+
+ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
+     if depth >= max_depth or cumulative_prob < config.CUTOFF:
+         return
+
+     current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+     top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
+
+     for idx, token_id in enumerate(top_k_ids.tolist()):
+         if token_id == config.SUB_TOKEN_ID:
+             continue  # Skip the substitute token to avoid circular definitions
+
+         token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
+         new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
+
+         new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
+
+         if new_cumulative_prob < config.CUTOFF:
+             continue
+
+         token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+
+         new_child = {
+             "token_id": token_id,
+             "token": token_str,
+             "cumulative_prob": new_cumulative_prob,
+             "children": []
+         }
+         data['children'].append(new_child)
+
+         build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
+
+ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
+     results_dict = {"token": "", "cumulative_prob": 1, "children": []}
+
+     # Reset the token embedding
+     token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
+     update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)
+
+     # Clear the model's cache if it has one
+     if hasattr(model, 'reset_cache'):
+         model.reset_cache()
+
+     input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
+     build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)
+
+     return results_dict
+
+ def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
+     current_max = max(current_max, node.get('cumulative_prob', 0))
+     if node.get('cumulative_prob', 1) > 0:
+         current_min = min(current_min, node.get('cumulative_prob', 1))
+     for child in node.get('children', []):
+         current_max, current_min = find_max_min_cumulative_weight(child, current_max, current_min)
+     return current_max, current_min
+
+ def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
+     cumulative_weight = max(cumulative_weight, min_weight)
+     log_weight = math.log(cumulative_weight, log_base) - math.log(min_weight, log_base)
+     log_max = math.log(max_weight, log_base) - math.log(min_weight, log_base)
+     amplified_weight = (log_weight / log_max) ** 2.5
+     scaled_weight = (amplified_weight * (max_thickness - min_thickness)) + min_thickness
+     return scaled_weight
+
+ def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_root=True, depth=0, trim_cutoff=0):
+     node_id = str(id(node))
+     token = node.get('token', '').strip()
+     cumulative_prob = node.get('cumulative_prob', 1)
+
+     if cumulative_prob < trim_cutoff and not is_root:
+         return
+
+     if is_root or token:
+         if parent and not is_root:
+             edge_weight = scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE)
+             dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=str(edge_weight))
+
+         label = "*" if is_root else token
+         dot.node(node_id, label=label, shape='plaintext', fontsize="36", fontname='Helvetica')
+
+     for child in node.get('children', []):
+         add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, trim_cutoff=trim_cutoff)
+
+ def create_tree_diagram(data, config, trim_cutoff=0):
+     dot = Digraph(comment='Definition Tree', format='png')
+     dot.attr(rankdir='LR', size='5040,5000', margin='0.06', nodesep='0.06', ranksep='1', dpi='120', bgcolor='white')
+
+     max_weight, min_weight = find_max_min_cumulative_weight(data)
+     add_nodes_edges(dot, data, config, max_weight, min_weight, trim_cutoff=trim_cutoff)
+
+     # Render the graph to PNG bytes in memory (Digraph.render(outfile=...) expects a file path, not a BytesIO)
+     output = BytesIO(dot.pipe(format='png'))
+
+     # Add white background
+     with Image.open(output) as img:
+         bg = Image.new("RGB", (img.width, 5000), (255, 255, 255))
+         y_offset = (5000 - img.height) // 2
+         bg.paste(img, (0, y_offset))
+         final_output = BytesIO()
+         bg.save(final_output, 'PNG')
+         final_output.seek(0)
+
+     return final_output
+
+ def get_neuronpedia_url(layer, feature):
+     return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
+
+ @torch.no_grad()
+ def generate_definition_tree_placeholder(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight):
+     return "Definition tree generation placeholder"
+
+ def gradio_interface():
+     model, tokenizer = load_tokenizer_and_model()
+     if model is None or tokenizer is None:
+         return gr.Interface(lambda: "Failed to load model and tokenizer. Please check the logs for more details.", inputs=[], outputs="text")
+
+     embeddings = get_embeddings(model)
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Gemma-2B SAE Feature Explorer")
+
+         with gr.Row():
+             with gr.Column():
+                 selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
+                 feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
+
+                 mode = gr.Radio(
+                     choices=["cosine distance token lists", "definition tree generation"],
+                     label="Select mode",
+                     value="cosine distance token lists"
+                 )
+
+                 weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
+                 use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
+                 scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token embeddings to token centroid)")
+
+                 num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
+                 denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
+                 use_pca = gr.Checkbox(label="Introduce first PCA component")
+                 pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")
+
+             with gr.Column():
+                 output = gr.Image(label="Tree Diagram Output")
+                 neuronpedia_embed = gr.HTML(label="Neuronpedia Embed")
+                 trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")
+
+         def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, trim_cutoff):
+             neuronpedia_url = get_neuronpedia_url(selected_sae.split(" ")[-1], feature_number)
+             neuronpedia_embed.update(value=f'<iframe src="{neuronpedia_url}" width="100%" height="300" frameborder="0"></iframe>')
+
+             if mode == "cosine distance token lists":
+                 # Keep the original functionality here
+                 pass
+             elif mode == "definition tree generation":
+                 embedding = embeddings[int(feature_number)].to(config.DEVICE)
+                 if use_token_centroid:
+                     token_centroid = torch.mean(embeddings, dim=0).to(config.DEVICE)
+                     embedding = token_centroid + scaling_factor * (embedding - token_centroid) / torch.norm(embedding - token_centroid)
+
+                 base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
+                 results_dict = generate_definition_tree(base_prompt, embedding, model, tokenizer, config)
+                 tree_diagram = create_tree_diagram(results_dict, config, trim_cutoff=trim_cutoff)
+                 return tree_diagram
+
+         def update_ui(mode_selected):
+             show_cosine_controls = mode_selected == "cosine distance token lists"
+             # One update per component in mode.change's outputs: num_exp, denom_exp, output
+             return (
+                 gr.update(visible=show_cosine_controls),
+                 gr.update(visible=show_cosine_controls),
+                 gr.update(visible=show_cosine_controls)
+             )
+
+         inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, trim_slider]
+         mode.change(
+             update_ui, inputs=[mode],
+             outputs=[num_exp, denom_exp, output]
+         )
+
+         gr.Button("Generate Output").click(update_output, inputs=inputs, outputs=[output])
+
+     return demo
+
+
+ if __name__ == "__main__":
+     iface = gradio_interface()
+     iface.launch()
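
For reference, the definition-tree pipeline this commit adds can also be exercised outside the Gradio UI. The following is a minimal sketch, not part of the committed file: it assumes app.py is importable, a GPU and HF_TOKEN_GEMMA are available so the Gemma-2B weights load, and it uses an arbitrary token's embedding (the choice of token ID 1000 is hypothetical) in place of an SAE feature vector.

# Sketch only: drives the new tree pipeline directly, under the assumptions above.
from app import (config, load_tokenizer_and_model, get_embeddings,
                 generate_definition_tree, create_tree_diagram)

model, tokenizer = load_tokenizer_and_model()
embeddings = get_embeddings(model)

# Hypothetical probe vector: an arbitrary token's embedding rather than an SAE feature.
embedding = embeddings[1000]

# Same prompt pattern the app builds around the substitute token.
sub_token = tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)
base_prompt = f'A typical definition of "{sub_token}" would be "'

tree = generate_definition_tree(base_prompt, embedding, model, tokenizer, config)
png_buffer = create_tree_diagram(tree, config, trim_cutoff=0.0001)

with open("definition_tree.png", "wb") as f:
    f.write(png_buffer.getvalue())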