charlieoneill committed on
Commit
b481357
1 Parent(s): e740e47

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +523 -0
app.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import json
4
+ import pandas as pd
5
+ from openai import OpenAI
6
+ import yaml
7
+ from typing import Optional, List, Dict, Tuple, Any
8
+ from topk_sae import FastAutoencoder
9
+ import torch
10
+ import plotly.express as px
11
+ from collections import Counter
12
+ from huggingface_hub import hf_hub_download
13
+ import os
14
+
15
+ import os
16
# Debug aid: echo the MODEL_REPO_ID environment variable (None when unset).
print(os.getenv('MODEL_REPO_ID'))

# Constants
EMBEDDING_MODEL = "text-embedding-3-small"  # model name passed to the embeddings API
d_model = 1536          # second size argument handed to FastAutoencoder
n_dirs = 6 * d_model    # number of SAE directions (9216); also part of data file names
k = 64                  # top-k setting; also part of data file names
auxk = 128              # auxiliary-k argument handed to FastAutoencoder

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
26
+
27
+ # Function to download all necessary files
28
def download_all_files():
    """Download every data/model file the app needs into the local ``data`` directory.

    Each file is fetched from the ``charlieoneill/saerch-ai-data`` Hugging Face
    repository with ``hf_hub_download`` and a confirmation line is printed
    after each one.
    """
    repo_id = "charlieoneill/saerch-ai-data"
    target_dir = "data"
    wanted = (
        "astroPH_paper_metadata.csv",
        "csLG_feature_analysis_results_64.json",
        "astroPH_topk_indices_64_9216_int32.npy",
        "astroPH_64_9216.pth",
        "astroPH_topk_values_64_9216_float16.npy",
        "csLG_abstract_texts.json",
        "csLG_topk_values_64_9216_float16.npy",
        "csLG_abstract_embeddings_float16.npy",
        "csLG_paper_metadata.csv",
        "csLG_64_9216.pth",
        "astroPH_abstract_texts.json",
        "astroPH_feature_analysis_results_64.json",
        "csLG_topk_indices_64_9216_int32.npy",
        "astroPH_abstract_embeddings_float16.npy",
    )

    for name in wanted:
        # Make sure the directory that will hold this file exists.
        destination = os.path.join(target_dir, name)
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        hf_hub_download(repo_id=repo_id, filename=name, local_dir=target_dir)
        print(f"Downloaded {name}")
51
+
52
# Fetch all data files up front; the loaders below read them from ./data.
download_all_files()

# The OpenAI key is supplied through the `openai_key` environment variable.
api_key = os.environ.get('openai_key')
if not api_key:
    # Fail fast: nothing below can embed a query without the key.
    raise ValueError("The environment variable 'openai_key' is not set.")

# Client used for every embedding request.
client = OpenAI(api_key=api_key)
66
+
67
+ # Function to load data for a specific subject
68
def load_subject_data(subject):
    """Load every artefact for one subject from the local ``data`` directory.

    Args:
        subject: File-name prefix of the subject (the module loads
            ``'astroPH'`` and ``'csLG'``).

    Returns:
        dict with the keys
        ``abstract_embeddings`` (float32 array, stored on disk as float16),
        ``abstract_texts`` and ``feature_analysis`` (parsed JSON),
        ``df_metadata`` (DataFrame read from the metadata CSV),
        ``topk_indices`` (array as stored), ``topk_values`` (float32 array),
        ``ae`` (the ``FastAutoencoder`` in eval mode on ``device``) and
        ``decoder`` (the checkpoint's ``decoder.weight`` as a NumPy array).
    """
    embeddings_path = f"data/{subject}_abstract_embeddings_float16.npy"
    texts_path = f"data/{subject}_abstract_texts.json"
    feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
    metadata_path = f'data/{subject}_paper_metadata.csv'
    topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
    topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
    # Built from k / n_dirs like the other paths instead of a hard-coded "64_9216".
    model_path = os.path.join("data", f"{subject}_{k}_{n_dirs}.pth")

    # Arrays are stored as float16 to save space; work in float32.
    abstract_embeddings = np.load(embeddings_path).astype(np.float32)
    with open(texts_path, 'r', encoding='utf-8') as f:
        abstract_texts = json.load(f)
    with open(feature_analysis_path, 'r', encoding='utf-8') as f:
        feature_analysis = json.load(f)
    df_metadata = pd.read_csv(metadata_path)
    topk_indices = np.load(topk_indices_path)  # already int32 on disk
    topk_values = np.load(topk_values_path).astype(np.float32)

    # Read the checkpoint once, remapped onto the active device so that a
    # checkpoint saved from a GPU still loads on a CPU-only host.
    state_dict = torch.load(model_path, map_location=device)

    ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik=0).to(device)
    ae.load_state_dict(state_dict)
    ae.eval()

    # Keep a NumPy copy of the decoder weights for the similarity analysis.
    decoder = state_dict['decoder.weight'].cpu().numpy()
    del state_dict

    return {
        'abstract_embeddings': abstract_embeddings,
        'abstract_texts': abstract_texts,
        'feature_analysis': feature_analysis,
        'df_metadata': df_metadata,
        'topk_indices': topk_indices,
        'topk_values': topk_values,
        'ae': ae,
        'decoder': decoder,
    }
122
+
123
+ # Load data for both subjects
124
# Everything is loaded eagerly at import time, keyed by subject name.
subject_data = {
    subject_name: load_subject_data(subject_name)
    for subject_name in ('astroPH', 'csLG')
}
128
+
129
+ # Update existing functions to use the selected subject's data
130
def get_embedding(text: Optional[str], model: str = EMBEDDING_MODEL) -> Optional[np.ndarray]:
    """Embed ``text`` with the OpenAI embeddings endpoint.

    Returns the embedding as a float32 NumPy array, or ``None`` when anything
    goes wrong (the error is printed, not raised).
    """
    try:
        response = client.embeddings.create(input=[text], model=model)
        vector = response.data[0].embedding
        result = np.array(vector, dtype=np.float32)
    except Exception as e:
        print(f"Error getting embedding: {e}")
        result = None
    return result
137
+
138
def intervened_hidden_to_intervened_embedding(topk_indices, topk_values, ae):
    """Decode a sparse (indices, values) code with ``ae.decode_sparse``.

    The call runs inside ``torch.no_grad()`` and its result is returned as is.
    """
    with torch.no_grad():
        decoded = ae.decode_sparse(topk_indices, topk_values)
    return decoded
141
+
142
+ # Function definitions for feature activation, co-occurrence, styling, etc.
143
def get_feature_activations(subject, feature_index, m=5, min_length=100):
    """Return the abstracts on which a feature activates most strongly.

    Args:
        subject: Key into ``subject_data``.
        feature_index: Feature id to look for in the subject's ``topk_indices``.
        m: Maximum number of abstracts to return; ``m <= 0`` yields ``[]``.
        min_length: Abstracts whose text is not longer than this are skipped.

    Returns:
        A list of ``(doc_id, abstract, activation_value)`` tuples ordered by
        decreasing activation value, at most ``m`` long.
    """
    # Without this guard m == 0 returned *every* match, because the
    # "collected == m" stop test below can never be true once an item is added.
    if m <= 0:
        return []

    data = subject_data[subject]
    abstract_texts = data['abstract_texts']
    topk_indices = data['topk_indices']
    topk_values = data['topk_values']

    doc_ids = abstract_texts['doc_ids']
    abstracts = abstract_texts['abstracts']

    # Rows in which the feature appears, and its activation in each row
    # (0 where the feature is absent).
    feature_mask = topk_indices == feature_index
    activated_indices = np.where(feature_mask.any(axis=1))[0]
    activation_values = np.where(feature_mask, topk_values, 0).max(axis=1)

    # Strongest activations first.
    order = np.argsort(-activation_values[activated_indices])

    top_m_abstracts = []
    for i in activated_indices[order]:
        if len(abstracts[i]) > min_length:
            top_m_abstracts.append((doc_ids[i], abstracts[i], activation_values[i]))
            if len(top_m_abstracts) == m:
                break

    return top_m_abstracts
168
+
169
def calculate_co_occurrences(subject, target_index, n_features=9216):
    """Count how often each feature shares a row with ``target_index``.

    Only rows of the subject's ``topk_indices`` that contain ``target_index``
    are considered. The result is an integer array of length ``n_features``
    whose entry ``j`` is the number of times feature ``j`` appears in those
    rows; the entry for ``target_index`` itself is left at 0.
    """
    topk_indices = subject_data[subject]['topk_indices']

    # Rows in which the target feature is active.
    rows_with_target = topk_indices[(topk_indices == target_index).any(axis=1)]

    # Tally every feature id seen in those rows, then drop the target itself.
    feature_ids, counts = np.unique(rows_with_target.flatten(), return_counts=True)
    not_target = feature_ids != target_index

    result = np.zeros(n_features, dtype=int)
    result[feature_ids[not_target]] = counts[not_target]
    return result
179
+
180
+ def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
181
+ cosine_values = df['Cosine similarity'].astype(float)
182
+ min_val = cosine_values.min()
183
+ max_val = cosine_values.max()
184
+
185
+ def color_similarity(val):
186
+ val = float(val)
187
+ # Normalize the value between 0 and 1
188
+ if is_top:
189
+ normalized_val = (val - min_val) / (max_val - min_val)
190
+ else:
191
+ # For bottom correlated, reverse the normalization
192
+ normalized_val = (max_val - val) / (max_val - min_val)
193
+
194
+ # Adjust the color intensity to avoid zero intensity
195
+ color_intensity = 0.2 + (normalized_val * 0.8) # This ensures the range is from 0.2 to 1.0
196
+
197
+ if is_top:
198
+ color = f'background-color: rgba(0, 255, 0, {color_intensity:.2f})'
199
+ else:
200
+ color = f'background-color: rgba(255, 0, 0, {color_intensity:.2f})'
201
+ return color
202
+
203
+ return df.style.applymap(color_similarity, subset=['Cosine similarity'])
204
+
205
def get_feature_from_index(subject, index):
    """Return the first feature-analysis entry whose ``'index'`` equals ``index``.

    Looks through ``subject_data[subject]['feature_analysis']`` in order and
    returns ``None`` when no entry matches.
    """
    for candidate in subject_data[subject]['feature_analysis']:
        if candidate['index'] == index:
            return candidate
    return None
208
+
209
def visualize_feature(subject, index):
    """Assemble everything shown for one feature in the visualisation tab.

    Args:
        subject: Key into ``subject_data``.
        index: Feature id (a column of the subject's decoder matrix).

    Returns:
        A 6-tuple ``(markdown, top_abstracts_df, top_correlated_styler,
        bottom_correlated_styler, co_occurrence_df, histogram_figure)``.
        When the feature is unknown the tuple is
        ``("Invalid feature index", None, None, None, None, None)``.
    """
    feature = get_feature_from_index(subject, index)
    if feature is None:
        # Same arity as the success path so callers can always unpack 6 values.
        return "Invalid feature index", None, None, None, None, None

    def label_of(feature_id):
        """Human-readable label for a feature id, with a generic fallback."""
        entry = get_feature_from_index(subject, int(feature_id))
        return entry['label'] if entry else f"Feature {feature_id}"

    output = f"# {feature['label']}\n\n"
    output += f"* Pearson correlation: {feature['pearson_correlation']:.4f}\n\n"
    output += f"* Density: {feature['density']:.4f}\n\n"

    # Top abstracts: the title is the text before the first blank line.
    top_m_abstracts = get_feature_activations(subject, index)
    df_top_abstracts = pd.DataFrame([
        {"Title": m[1].split('\n\n')[0], "Activation value": f"{m[2]:.4f}"}
        for m in top_m_abstracts
    ])

    # Activation value distribution (0 for rows where the feature is absent).
    topk_indices = subject_data[subject]['topk_indices']
    topk_values = subject_data[subject]['topk_values']
    activation_values = np.where(topk_indices == index, topk_values, 0).max(axis=1)
    fig2 = px.histogram(x=activation_values, nbins=50)
    fig2.update_layout(
        xaxis_title='Activation value',
        yaxis_title=None,
        yaxis_type='log',
        height=220,
    )

    # Cosine similarity between this feature's decoder column and all others.
    decoder = subject_data[subject]['decoder']
    feature_vector = decoder[:, index]
    # Keep the original feature ids next to the reduced matrix: positions in
    # the reduced matrix are shifted by one for every id above `index`, so
    # they must be mapped back before being used as feature ids.
    other_ids = np.delete(np.arange(decoder.shape[1]), index)
    other_columns = decoder[:, other_ids]
    cosine_similarities = np.dot(feature_vector, other_columns) / (
        np.linalg.norm(other_columns, axis=0) * np.linalg.norm(feature_vector)
    )

    topk = 5
    top_positions = np.argsort(-cosine_similarities)[:topk]
    df_top_correlated = pd.DataFrame({
        "Feature": [label_of(i) for i in other_ids[top_positions]],
        "Cosine similarity": [f"{v:.4f}" for v in cosine_similarities[top_positions]],
    })
    df_top_correlated_styled = style_dataframe(df_top_correlated, is_top=True)

    bottomk = 5
    bottom_positions = np.argsort(cosine_similarities)[:bottomk]
    df_bottom_correlated = pd.DataFrame({
        "Feature": [label_of(i) for i in other_ids[bottom_positions]],
        "Cosine similarity": [f"{v:.4f}" for v in cosine_similarities[bottom_positions]],
    })
    df_bottom_correlated_styled = style_dataframe(df_bottom_correlated, is_top=False)

    # Features that most often appear in the same rows as this one.
    co_occurrences = calculate_co_occurrences(subject, index, n_features=decoder.shape[1])
    top_co_ids = np.argsort(-co_occurrences)[:topk]
    df_co_occurrences = pd.DataFrame({
        "Feature": [label_of(i) for i in top_co_ids],
        "Co-occurrences": co_occurrences[top_co_ids],
    })

    return (
        output,
        df_top_abstracts,
        df_top_correlated_styled,
        df_bottom_correlated_styled,
        df_co_occurrences,
        fig2,
    )
283
+
284
+ # Modify the main interface function
285
def create_interface():
    """Build the Gradio Blocks app and return it without launching it.

    The app has a subject dropdown and two tabs:

    * "SAErch": embeds the query, encodes it with the subject's autoencoder,
      shows one slider per active feature and re-runs the similarity search
      from the (possibly edited) feature values.
    * "Feature Visualisation": looks up a feature by label and displays the
      output of ``visualize_feature``.

    Changing the subject clears the shared state and the visualisation widgets.
    """
    # `#custom-slider-*` is not a valid CSS selector, so the rule was dropped
    # by the browser. An attribute-prefix selector matches every element id of
    # the form `custom-slider-<index>` assigned to manually added sliders.
    custom_css = """
    [id^="custom-slider-"] {
        background-color: #ffe6e6;
    }
    """

    def match_feature_labels(search_text, current_subject):
        """Return a CheckboxGroup with up to 10 'label (index)' matches."""
        if not search_text:
            return gr.CheckboxGroup(choices=[])
        needle = search_text.lower()
        matches = [
            f"{f['label']} ({f['index']})"
            for f in subject_data[current_subject]['feature_analysis']
            if needle in f['label'].lower()
        ]
        return gr.CheckboxGroup(choices=matches[:10])

    with gr.Blocks(css=custom_css) as demo:
        subject = gr.Dropdown(choices=['astroPH', 'csLG'], label="Select Subject", value='astroPH')

        with gr.Tabs():
            with gr.Tab("SAErch"):
                input_text = gr.Textbox(label="input")
                search_results_state = gr.State([])
                feature_values_state = gr.State([])
                feature_indices_state = gr.State([])
                manually_added_features_state = gr.State([])

                def update_search_results(feature_values, feature_indices, manually_added_features, current_subject):
                    """Decode the feature values and return (results, values, indices)."""
                    ae = subject_data[current_subject]['ae']
                    abstract_embeddings = subject_data[current_subject]['abstract_embeddings']
                    abstract_texts = subject_data[current_subject]['abstract_texts']
                    df_metadata = subject_data[current_subject]['df_metadata']

                    # Combine manually added features with query-generated ones;
                    # manual ones go first and default to 0.0 when the query
                    # did not activate them.
                    all_indices = []
                    all_values = []
                    for index in manually_added_features:
                        if index not in all_indices:
                            all_indices.append(index)
                            all_values.append(
                                feature_values[feature_indices.index(index)]
                                if index in feature_indices else 0.0
                            )
                    for index, value in zip(feature_indices, feature_values):
                        if index not in all_indices:
                            all_indices.append(index)
                            all_values.append(value)

                    # Reconstruct the query embedding. Values are forced to
                    # float32 because a list of integer-valued slider readings
                    # would otherwise become an int64 tensor.
                    topk_indices = torch.tensor(all_indices).to(device)
                    topk_values = torch.tensor(all_values, dtype=torch.float32).to(device)

                    intervened_embedding = intervened_hidden_to_intervened_embedding(topk_indices, topk_values, ae)
                    intervened_embedding = intervened_embedding.cpu().numpy().flatten()

                    # Similarity search: ten highest dot products.
                    sims = np.dot(abstract_embeddings, intervened_embedding)
                    topk_indices_search = np.argsort(sims)[::-1][:10]
                    doc_ids = abstract_texts['doc_ids']
                    topk_doc_ids = [doc_ids[i] for i in topk_indices_search]

                    search_results = []
                    for doc_id in topk_doc_ids:
                        rows = df_metadata[df_metadata['arxiv_id'] == doc_id]
                        if rows.empty:
                            # No metadata row for this document; skip it
                            # rather than fail the whole search.
                            continue
                        metadata = rows.iloc[0]
                        title = metadata['title'].replace('[', '').replace(']', '')
                        search_results.append([
                            title,
                            int(metadata['citation_count']),
                            int(metadata['year']),
                        ])

                    return search_results, all_values, all_indices

                @gr.render(inputs=[input_text, search_results_state, feature_values_state, feature_indices_state, manually_added_features_state, subject])
                def show_components(text, search_results, feature_values, feature_indices, manually_added_features, current_subject):
                    if not text:
                        return gr.Markdown("## No Input Provided")

                    # NOTE(review): `last_query` lives on the function object
                    # and is therefore shared by every session — confirm this
                    # is acceptable for multi-user deployments.
                    if not search_results or text != getattr(show_components, 'last_query', None):
                        query_embedding = get_embedding(text)
                        if query_embedding is None:
                            # get_embedding returns None on failure; passing
                            # that to torch.tensor would raise.
                            return gr.Markdown("## Could not embed the query. Please try again.")
                        show_components.last_query = text

                        ae = subject_data[current_subject]['ae']
                        with torch.no_grad():
                            recons, z_dict = ae(torch.tensor(query_embedding).unsqueeze(0).to(device))
                            topk_indices = z_dict['topk_indices'][0].cpu().numpy()
                            topk_values = z_dict['topk_values'][0].cpu().numpy()

                        feature_values = topk_values.tolist()
                        feature_indices = topk_indices.tolist()
                        search_results, feature_values, feature_indices = update_search_results(
                            feature_values, feature_indices, manually_added_features, current_subject
                        )

                    with gr.Row():
                        with gr.Column(scale=2):
                            df = gr.Dataframe(
                                headers=["Title", "Citation Count", "Year"],
                                value=search_results,
                                label="Top 10 Search Results"
                            )

                            feature_search = gr.Textbox(label="Search Feature Labels")
                            feature_matches = gr.CheckboxGroup(label="Matching Features", choices=[])
                            add_button = gr.Button("Add Selected Features")

                            def search_feature_labels(search_text):
                                return match_feature_labels(search_text, current_subject)

                            feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])

                            def on_add_features(selected_features, current_values, current_indices, manually_added_features):
                                if selected_features:
                                    # Labels look like "<label> (<index>)".
                                    new_indices = [int(f.split('(')[-1].strip(')')) for f in selected_features]
                                    # Append while keeping order and dropping duplicates.
                                    manually_added_features = list(dict.fromkeys(manually_added_features + new_indices))
                                return gr.CheckboxGroup(value=[]), current_values, current_indices, manually_added_features

                            add_button.click(
                                on_add_features,
                                inputs=[feature_matches, feature_values_state, feature_indices_state, manually_added_features_state],
                                outputs=[feature_matches, feature_values_state, feature_indices_state, manually_added_features_state]
                            )

                        with gr.Column(scale=1):
                            update_button = gr.Button("Update Results")
                            sliders = []
                            for i, (value, index) in enumerate(zip(feature_values, feature_indices)):
                                feature = next((f for f in subject_data[current_subject]['feature_analysis'] if f['index'] == index), None)
                                label = f"{feature['label']} ({index})" if feature else f"Feature {index}"

                                # Manually added features get a prefix and an
                                # element id that the custom CSS targets.
                                if index in manually_added_features:
                                    label = f"[Custom] {label}"
                                    slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=value, label=label, key=f"slider-{index}", elem_id=f"custom-slider-{index}")
                                else:
                                    slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=value, label=label, key=f"slider-{index}")

                                sliders.append(slider)

                            def on_slider_change(*values):
                                manually_added_features = values[-1]
                                slider_values = list(values[:-1])

                                # Sliders were created in feature_indices order,
                                # so the ids can be reused directly instead of
                                # being parsed back out of the slider labels.
                                slider_indices = list(feature_indices)

                                new_results, new_values, new_indices = update_search_results(
                                    slider_values, slider_indices, manually_added_features, current_subject
                                )
                                return new_results, new_values, new_indices, manually_added_features

                            update_button.click(
                                on_slider_change,
                                inputs=sliders + [manually_added_features_state],
                                outputs=[search_results_state, feature_values_state, feature_indices_state, manually_added_features_state]
                            )

                    return [df, feature_search, feature_matches, add_button, update_button] + sliders

            with gr.Tab("Feature Visualisation"):
                gr.Markdown("# Feature Visualiser")
                with gr.Row():
                    feature_search = gr.Textbox(label="Search Feature Labels")
                    feature_matches = gr.CheckboxGroup(label="Matching Features", choices=[])
                    visualize_button = gr.Button("Visualize Feature")

                feature_info = gr.Markdown()
                abstracts_heading = gr.Markdown("## Top 5 Abstracts")
                top_abstracts = gr.Dataframe(
                    headers=["Title", "Activation value"],
                    interactive=False
                )

                gr.Markdown("## Correlated Features")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Top 5 Correlated Features")
                        top_correlated = gr.Dataframe(
                            headers=["Feature", "Cosine similarity"],
                            interactive=False
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("### Bottom 5 Correlated Features")
                        bottom_correlated = gr.Dataframe(
                            headers=["Feature", "Cosine similarity"],
                            interactive=False
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Top 5 Co-occurring Features")
                        co_occurring_features = gr.Dataframe(
                            headers=["Feature", "Co-occurrences"],
                            interactive=False
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("## Activation Value Distribution")
                        activation_dist = gr.Plot()

                feature_search.change(match_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])

                def on_visualize(selected_features, current_subject):
                    if not selected_features:
                        return "Please select a feature to visualize.", None, None, None, None, None, "", []

                    # Only the first ticked entry is visualised; its id is the
                    # number inside the trailing parentheses.
                    feature_index = int(selected_features[0].split('(')[-1].strip(')'))
                    feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist = visualize_feature(current_subject, feature_index)

                    # The last two values clear the search box and the checkbox group.
                    return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", []

                visualize_button.click(
                    on_visualize,
                    inputs=[feature_matches, subject],
                    outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches]
                )

        def on_subject_change(new_subject):
            """Reset the shared state and the visualisation widgets."""
            # Order matches the `outputs` list below: four states, the query
            # box, the feature search box, the checkbox group, six viz widgets.
            return [], [], [], [], "", "", [], None, None, None, None, None, None

        subject.change(
            on_subject_change,
            inputs=[subject],
            # `feature_matches` used to be listed twice here; once is enough.
            outputs=[search_results_state, feature_values_state, feature_indices_state, manually_added_features_state,
                     input_text, feature_search, feature_matches,
                     feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist]
        )

    return demo
519
+
520
+ # Launch the interface
521
# Build and serve the app only when this file is run as a script.
if __name__ == "__main__":
    # The Blocks object is bound to the module-level name `demo` before launching.
    demo = create_interface()
    demo.launch()