Nu Appleblossom committed on
Commit c1d7828 · 1 Parent(s): c1d81b1

Initial application file + requirements

Files changed (2)
  1. app.py +205 -8
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,14 +1,211 @@
  import gradio as gr
- import spaces
  import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from safetensors import safe_open
+ import os
+ import requests
+ import json
+ from sklearn.decomposition import PCA
+ import logging
+ import time
+ from dotenv import load_dotenv
+ from huggingface_hub import hf_hub_download
+ import spaces
+
+ # Load environment variables
+ load_dotenv()
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ class Config:
+     def __init__(self):
+         self.MODEL_NAME = "google/gemma-2b"
+         self.ACCESS_TOKEN = os.getenv('HF_ACCESS_TOKEN')
+         self.DEVICE = "cpu"  # Will be updated to "cuda" when GPU is available
+         self.DTYPE = torch.float32
+
+ config = Config()
+
+ def load_tokenizer():
+     try:
+         return AutoTokenizer.from_pretrained(config.MODEL_NAME, token=config.ACCESS_TOKEN)
+     except Exception as e:
+         logger.error(f"Error loading tokenizer: {str(e)}")
+         return None
+
+ def load_token_embeddings():
+     try:
+         embeddings_path = hf_hub_download(
+             repo_id="mwatkins1970/gemma-2b-embeddings",
+             filename="gemma_2b_embeddings.pt",
+             token=os.getenv("HF_ACCESS_TOKEN")
+         )
+         embeddings = torch.load(embeddings_path, map_location=config.DEVICE)
+         return embeddings.to(dtype=config.DTYPE)
+     except Exception as e:
+         logger.error(f"Error loading token embeddings: {str(e)}")
+         return None
+
+ def load_sae_weights(sae_name):
+     start_time = time.time()
+     base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
+
+     sae_urls = {
+         "Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
+         "Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
+         "Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
+         "Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors"
+     }
+
+     if sae_name not in sae_urls:
+         raise ValueError(f"Unknown SAE: {sae_name}")
+
+     url = f'{base_url}{sae_urls[sae_name]}?download=true'
+     local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'

- zero = torch.Tensor([0]).cuda()
- print(zero.device) # <-- 'cpu'
+     if not os.path.exists(local_filename):
+         try:
+             response = requests.get(url)
+             response.raise_for_status()
+             with open(local_filename, 'wb') as f:
+                 f.write(response.content)
+             logger.info(f'SAE weights for {sae_name} downloaded successfully!')
+         except requests.RequestException as e:
+             logger.error(f"Failed to download SAE weights for {sae_name}: {str(e)}")
+             return None, None
+
+     try:
+         with safe_open(local_filename, framework="pt") as f:
+             w_dec = f.get_tensor("W_dec").to(device=config.DEVICE, dtype=config.DTYPE)
+             w_enc = f.get_tensor("W_enc").to(device=config.DEVICE, dtype=config.DTYPE)
+
+         logger.info(f"Successfully loaded weights for {sae_name}")
+         logger.info(f"Time taken to load weights: {time.time() - start_time:.2f} seconds")
+         return w_enc, w_dec
+     except Exception as e:
+         logger.error(f"Error loading SAE weights for {sae_name}: {str(e)}")
+         return None, None
+
+ @torch.no_grad()
+ def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor):
+     if weight_type == "encoder":
+         feature_vector = w_enc[:, feature_number]
+     else:
+         feature_vector = w_dec[feature_number]
+
+     if use_token_centroid:
+         feature_vector = token_centroid + scaling_factor * (feature_vector - token_centroid) / torch.norm(feature_vector - token_centroid)
+
+     return feature_vector
+
+ def perform_pca(_embeddings):
+     pca = PCA(n_components=1)
+     pca.fit(_embeddings.cpu().numpy())
+     pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
+     return F.normalize(pca_direction, p=2, dim=0)
+
+ @torch.no_grad()
+ def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
+     feature_direction = F.normalize(_feature_vector - _token_centroid, p=2, dim=0)
+     combined_direction = (1 - pca_weight) * feature_direction + pca_weight * _pca_direction
+     combined_direction = F.normalize(combined_direction, p=2, dim=0)
+     return _token_centroid + target_distance * combined_direction
+
+ @torch.no_grad()
+ def find_closest_tokens(_emb, _token_embeddings, _tokenizer, top_k=500, num_exp=1.4, denom_exp=1.0):
+     token_centroid = torch.mean(_token_embeddings, dim=0)
+     emb_norm = F.normalize(_emb.view(1, -1), p=2, dim=1)
+     centroid_norm = F.normalize(token_centroid.view(1, -1), p=2, dim=1)
+     normalized_embeddings = F.normalize(_token_embeddings, p=2, dim=1)
+
+     similarities_emb = torch.mm(emb_norm, normalized_embeddings.t()).squeeze()
+     similarities_centroid = torch.mm(centroid_norm, normalized_embeddings.t()).squeeze()
+
+     distances_emb = torch.pow(1 - similarities_emb, num_exp)
+     distances_centroid = torch.pow(1 - similarities_centroid, denom_exp)
+
+     ratios = distances_emb / distances_centroid
+     top_ratios, top_indices = torch.topk(ratios, k=top_k, largest=False)
+
+     closest_tokens = [_tokenizer.decode([idx.item()]) for idx in top_indices]
+     return list(zip(closest_tokens, top_ratios.tolist()))
+
+ def get_neuronpedia_url(layer, feature):
+     return f"https://neuronpedia.org/gemma-2b/{layer}-res-jb/{feature}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
+
+ # Global variables to store loaded resources
+ tokenizer = None
+ token_embeddings = None
+ w_enc_dict = {}
+ w_dec_dict = {}

  @spaces.GPU
- def greet(n):
-     print(zero.device) # <-- 'cuda:0'
-     return f"Hello {zero + n} Tensor"
+ def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
+     global tokenizer, token_embeddings, w_enc_dict, w_dec_dict
+
+     if tokenizer is None:
+         tokenizer = load_tokenizer()
+     if token_embeddings is None:
+         token_embeddings = load_token_embeddings()
+
+     if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
+         w_enc, w_dec = load_sae_weights(selected_sae)
+         w_enc_dict[selected_sae] = w_enc
+         w_dec_dict[selected_sae] = w_dec
+     else:
+         w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
+
+     if w_enc is None or w_dec is None:
+         return "Failed to load SAE weights. Please try selecting a different SAE or rerun the app."
+
+     token_centroid = torch.mean(token_embeddings, dim=0)
+     feature_vector = create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centroid, use_token_centroid, scaling_factor)
+
+     if use_pca:
+         pca_direction = perform_pca(token_embeddings)
+         feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
+
+     closest_tokens_with_values = find_closest_tokens(
+         feature_vector, token_embeddings, tokenizer,
+         top_k=500, num_exp=num_exp, denom_exp=denom_exp
+     )
+
+     token_list = [token for token, _ in closest_tokens_with_values]
+     result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
+     result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
+     result += "Top 500 list:\n"
+     result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
+
+     return result
+
+ def gradio_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# SAE Feature Explorer")
+
+         with gr.Row():
+             with gr.Column():
+                 selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
+                 feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
+                 weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector")
+                 use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
+                 scaling_factor = gr.Slider(minimum=0.1, maximum=10.0, value=3.8, label="Scaling factor (3.8 is mean distance from token centroid)")
+                 use_pca = gr.Checkbox(label="Introduce first PCA component")
+                 pca_weight = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="PCA weight")
+                 num_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.4, label="Numerator exponent m")
+                 denom_exp = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label="Denominator exponent n")
+
+             with gr.Column():
+                 output = gr.Textbox(label="Results")
+
+         submit_btn = gr.Button("Generate")
+         submit_btn.click(process_input, inputs=[selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp], outputs=output)
+
+     return demo

- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
- demo.launch()
+ if __name__ == "__main__":
+     iface = gradio_interface()
+     iface.launch()
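
The ranking at the heart of the new app is find_closest_tokens: each vocabulary token t is scored by (1 - cos(emb, t))**m / (1 - cos(centroid, t))**n, and the tokens with the smallest ratios are reported. Below is a minimal, self-contained sketch of just that computation; the random tensors, the 1000-token vocabulary, and the 2048 embedding width are made-up stand-ins for the real Gemma-2B embedding table (illustrative only, not part of the commit):

# Illustrative sketch only -- not part of the commit.
# Scores tokens by (1 - cos(emb, token))**m / (1 - cos(centroid, token))**n
# and keeps the smallest ratios, mirroring find_closest_tokens in app.py.
import torch
import torch.nn.functional as F

token_embeddings = torch.randn(1000, 2048)   # stand-in for the Gemma-2B token embedding table
emb = torch.randn(2048)                      # stand-in for a feature / ghost-token vector
centroid = token_embeddings.mean(dim=0)

tokens_norm = F.normalize(token_embeddings, p=2, dim=1)
sim_emb = tokens_norm @ F.normalize(emb, p=2, dim=0)            # cosine similarity to the feature vector
sim_centroid = tokens_norm @ F.normalize(centroid, p=2, dim=0)  # cosine similarity to the token centroid

m, n = 1.4, 1.0                              # the app's default numerator/denominator exponents
ratios = torch.pow(1 - sim_emb, m) / torch.pow(1 - sim_centroid, n)
top_ratios, top_indices = torch.topk(ratios, k=10, largest=False)
print(top_indices.tolist(), top_ratios.tolist())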
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ transformers
+ safetensors
+ requests
+ scikit-learn
+ python-dotenv
+ huggingface_hub
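
A note on local setup (a sketch, not part of the commit): app.py reads HF_ACCESS_TOKEN through python-dotenv before downloading the google/gemma-2b tokenizer, the mwatkins1970/gemma-2b-embeddings file, and the SAE weights, so it is worth confirming the variable is visible before launching. The helper name local_check.py and the placeholder .env contents are hypothetical.

# local_check.py -- hypothetical helper, not part of the commit.
# Confirms the HF_ACCESS_TOKEN that app.py expects (loaded via python-dotenv) is set,
# e.g. from a .env file containing a line like HF_ACCESS_TOKEN=hf_xxx (placeholder value).
import os
from dotenv import load_dotenv

load_dotenv()
if os.getenv("HF_ACCESS_TOKEN"):
    print("HF_ACCESS_TOKEN found; token-gated downloads should work.")
else:
    print("HF_ACCESS_TOKEN missing; tokenizer and embedding downloads will fail.")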