"""Utilities for interactive visualization of extracted graphs"""
import pandas as pd
import numpy as np
import plotly.express as px
import streamlit as st
def create_scatter_plot_with_filter(graph_data):
"""
Crea uno scatter plot interattivo con filtro per cumulative influence
Args:
graph_data: Dizionario contenente i dati del grafo (nodes, metadata, etc)
"""
if 'nodes' not in graph_data:
st.warning("⚠️ No nodes found in graph data")
return
    # Extract prompt_tokens from the metadata to map ctx_idx -> token
prompt_tokens = graph_data.get('metadata', {}).get('prompt_tokens', [])
    # Build the mapping ctx_idx -> token
token_map = {i: token for i, token in enumerate(prompt_tokens)}
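    # Illustrative example: prompt_tokens ["<bos>", "The", " cat"] -> token_map {0: "<bos>", 1: "The", 2: " cat"}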
    # Extract the nodes with ctx_idx, layer and influence
    # Map the 'E' (embeddings) layer to -1; numeric layers stay numeric
scatter_data = []
    skipped_nodes = []  # For logging problematic nodes
for node in graph_data['nodes']:
layer_val = node.get('layer', '')
try:
            # Map the embedding layer to -1
            if str(layer_val).upper() == 'E':
                layer_numeric = -1
            else:
                # Try to convert to int
                layer_numeric = int(layer_val)
            # Handle influence: use a minimum value if missing or zero
            influence_val = node.get('influence', 0)
            if influence_val is None or influence_val == 0:
                influence_val = 0.001  # Minimum value for visibility
            # Get ctx_idx and map it to its token
ctx_idx_val = node.get('ctx_idx', 0)
token_str = token_map.get(ctx_idx_val, f"ctx_{ctx_idx_val}")
            # Extract feature_index from node_id ONLY for SAE nodes
            # SAE format: "layer_featureIndex_sequence" → e.g. "24_79427_7"
            # Other types (MLP error, embeddings, logits) use different formats
node_id = node.get('node_id', '')
node_type = node.get('feature_type', '')
feature_idx = None
if node_type == 'cross layer transcoder':
                # SAE nodes only: extract feature_idx from node_id
if node_id and '_' in node_id:
parts = node_id.split('_')
if len(parts) >= 2:
try:
                            # The second element is the feature_index
feature_idx = int(parts[1])
except (ValueError, IndexError):
pass
                # If parsing fails for an SAE node, skip it!
if feature_idx is None:
skipped_nodes.append(f"layer={layer_val}, node_id={node_id}, type=SAE")
                    continue  # Skip malformed SAE nodes
else:
                # For non-SAE nodes (embeddings, logits, MLP error, etc.):
                # use -1 as a placeholder - do NOT parse it from node_id!
feature_idx = -1
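            # Example (using the format above): node_id "24_79427_7" splits into ["24", "79427", "7"],
            # so feature_idx = 79427 for that transcoder feature; non-SAE nodes keep feature_idx = -1.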
scatter_data.append({
'layer': layer_numeric,
'ctx_idx': ctx_idx_val,
'token': token_str,
'id': node_id,
'influence': influence_val,
                'feature': feature_idx  # Now holds the correct index, or -1 for non-feature nodes
})
except (ValueError, TypeError):
            # Skip nodes with an invalid layer
continue
    # Log skipped nodes, if any
if skipped_nodes:
st.warning(f"⚠️ {len(skipped_nodes)} feature nodes with malformed node_id were skipped")
with st.expander("Skipped nodes details"):
            for node_info in skipped_nodes[:10]:  # Show only the first 10
st.text(node_info)
if len(skipped_nodes) > 10:
st.text(f"... and {len(skipped_nodes) - 10} more nodes")
if not scatter_data:
st.warning("⚠️ No valid nodes found for plotting")
return
scatter_df = pd.DataFrame(scatter_data)
    # Clean NaN and invalid values
scatter_df['influence'] = scatter_df['influence'].fillna(0.001)
scatter_df['influence'] = scatter_df['influence'].replace(0, 0.001)
    # === BINNING TO AVOID OVERLAPS (Neuronpedia-style) ===
    # For each (ctx_idx, layer) combination, distribute the nodes across sub-columns
    bin_width = 0.3  # Sub-column width
scatter_df['sub_column'] = 0
for (ctx, layer), group in scatter_df.groupby(['ctx_idx', 'layer']):
n_nodes = len(group)
if n_nodes > 1:
            # Compute how many sub-columns are needed (max 5 to avoid spreading too much)
            n_bins = min(5, int(np.ceil(np.sqrt(n_nodes))))
            # Assign each node to a sub-column
            for i, idx in enumerate(group.index):
                sub_col = (i % n_bins) - (n_bins - 1) / 2  # Center around 0
scatter_df.at[idx, 'sub_column'] = sub_col * bin_width
    # Apply the offset to create the sub-columns
scatter_df['ctx_idx_display'] = scatter_df['ctx_idx'] + scatter_df['sub_column']
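    # Worked example (illustrative counts): 7 nodes sharing (ctx_idx=3, layer=5) get
    # n_bins = min(5, ceil(sqrt(7))) = 3, so their offsets cycle through (-1, 0, +1) * 0.3
    # and ctx_idx_display becomes 2.7, 3.0 or 3.3.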
    # === CUMULATIVE INFLUENCE FILTER ===
st.markdown("### 3️⃣ Filter Features by Cumulative Influence Coverage")
    # Compute the maximum influence value present in the data
max_influence = scatter_df['influence'].max()
    # Show the node_threshold used during generation (if available)
node_threshold_used = graph_data.get('metadata', {}).get('node_threshold', None)
    if node_threshold_used is not None:
        st.info(f"""
**The `influence` field is the cumulative coverage (0-{max_influence:.2f})** computed by circuit tracer pruning
(this graph was generated with `node_threshold={node_threshold_used}`).
When nodes are sorted by descending influence, a node with `influence=0.65` means that the nodes
**up to and including it** cover 65% of the total influence.
""")
    else:
        st.info(f"""
**The `influence` field is the cumulative coverage (0-{max_influence:.2f})** computed by circuit tracer pruning.
When nodes are sorted by descending influence, a node with `influence=0.65` means that the nodes
**up to and including it** cover 65% of the total influence.
""")
cumulative_threshold = st.slider(
"Cumulative Influence Threshold",
min_value=0.0,
max_value=float(max_influence),
value=float(max_influence),
step=0.01,
key="cumulative_slider_main",
help=f"Keep only nodes with influence ≤ threshold. Range: 0.0 - {max_influence:.2f} (max in data)"
)
    # Checkbox to filter out reconstruction error nodes
filter_error_nodes = st.checkbox(
"Exclude Reconstruction Error Nodes (feature = -1)",
value=False,
key="filter_error_checkbox",
help="Reconstruction error nodes represent the part of the model not explained by SAE features"
)
    # Filter directly on the influence field from the JSON
num_total = len(scatter_df)
    # Identify reconstruction error nodes (feature = -1) - the KPI is computed later
is_error_node = scatter_df['feature'] == -1
n_error_total = is_error_node.sum()
pct_error_nodes = (n_error_total / num_total * 100) if num_total > 0 else 0
    # Identify embeddings and logits, which are always kept
    is_embedding = scatter_df['layer'] == -1  # Layer 'E' mapped to -1
    # Logits sit on the maximum layer (e.g. layer 27 for gemma-2-2b: 26 layers + 1)
max_layer = scatter_df['layer'].max()
is_logit = scatter_df['layer'] == max_layer
    # Apply the combined filters: influence threshold + error nodes (if the checkbox is active)
    # Compare against max_influence (not a hard-coded 1.0) so "slider at max = keep everything" also holds when influence exceeds 1.0
    if cumulative_threshold < float(max_influence):
mask_influence = scatter_df['influence'] <= cumulative_threshold
mask_keep = mask_influence | is_embedding | is_logit
else:
mask_keep = pd.Series([True] * len(scatter_df), index=scatter_df.index)
    # Apply the error-node filter if the checkbox is active
    if filter_error_nodes:
        # Exclude error nodes (feature = -1), but keep embeddings/logits
mask_not_error = (scatter_df['feature'] != -1) | is_embedding | is_logit
mask_keep = mask_keep & mask_not_error
scatter_filtered = scatter_df[mask_keep].copy()
    # Effective influence threshold (max influence among the filtered nodes, excluding embeddings/logits)
feature_nodes_filtered = scatter_filtered[~((scatter_filtered['layer'] == -1) | (scatter_filtered['layer'] == max_layer))]
if len(feature_nodes_filtered) > 0:
threshold_influence = feature_nodes_filtered['influence'].max()
else:
threshold_influence = 0.0
num_selected = len(scatter_filtered)
    # Count embeddings, features and error nodes in the filtered dataset (before removing logits)
is_embedding_filtered = scatter_filtered['layer'] == -1
max_layer_filtered = scatter_filtered['layer'].max()
is_logit_filtered = scatter_filtered['layer'] == max_layer_filtered
is_error_filtered = scatter_filtered['feature'] == -1
n_embeddings = len(scatter_filtered[is_embedding_filtered])
n_error_nodes = len(scatter_filtered[is_error_filtered & ~is_embedding_filtered & ~is_logit_filtered])
n_features = len(scatter_filtered[~(is_embedding_filtered | is_logit_filtered | is_error_filtered)])
n_logits_excluded = len(scatter_filtered[is_logit_filtered])
n_error_excluded = n_error_total - n_error_nodes if filter_error_nodes else 0
    # Show filter statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Nodes", num_total)
with col2:
st.metric("Selected Nodes", num_selected)
with col3:
pct = (num_selected / num_total * 100) if num_total > 0 else 0
st.metric("% Nodes", f"{pct:.1f}%")
with col4:
st.metric("Influence Threshold", f"{threshold_influence:.6f}")
    # Use the filtered dataframe for the plot
    scatter_df = scatter_filtered
    # Recompute the sub-columns for the filtered dataset
scatter_df = scatter_df.copy()
scatter_df['sub_column'] = 0
for (ctx, layer), group in scatter_df.groupby(['ctx_idx', 'layer']):
n_nodes = len(group)
if n_nodes > 1:
n_bins = min(5, int(np.ceil(np.sqrt(n_nodes))))
for i, idx in enumerate(group.index):
sub_col = (i % n_bins) - (n_bins - 1) / 2
scatter_df.at[idx, 'sub_column'] = sub_col * bin_width
scatter_df['ctx_idx_display'] = scatter_df['ctx_idx'] + scatter_df['sub_column']
    # Compute node_influence (marginal influence), used for the circle/square radius
    # If it is not present in the JSON (older graphs), compute it on the fly
if 'node_influence' not in scatter_df.columns:
        # Estimate marginal influence as the difference between consecutive cumulative values
df_sorted_by_cumul = scatter_df.sort_values('influence').reset_index(drop=True)
df_sorted_by_cumul['node_influence'] = df_sorted_by_cumul['influence'].diff()
df_sorted_by_cumul.loc[0, 'node_influence'] = df_sorted_by_cumul.loc[0, 'influence']
        # Remap back onto the original dataframe
node_id_to_marginal = dict(zip(df_sorted_by_cumul['id'], df_sorted_by_cumul['node_influence']))
scatter_df['node_influence'] = scatter_df['id'].map(node_id_to_marginal).fillna(scatter_df['influence'])
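    # Illustrative example: cumulative influences [0.20, 0.50, 0.65] (sorted ascending) give marginal
    # estimates diff() = [0.20, 0.30, 0.15]; see the note under the plot for why this is only a proxy.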
    # COMPUTE THE ERROR-NODE KPIs (now that node_influence is available)
    # Use scatter_df (the full dataset before logit removal) for the global KPIs
is_error_in_complete = scatter_df['feature'] == -1
total_node_influence = scatter_df['node_influence'].sum()
error_node_influence = scatter_df[is_error_in_complete]['node_influence'].sum()
pct_error_influence = (error_node_influence / total_node_influence * 100) if total_node_influence > 0 else 0
    # Show the reconstruction-error-node KPIs (before the plot)
col1, col2 = st.columns(2)
with col1:
st.metric(
"% Error Nodes",
f"{pct_error_nodes:.1f}%",
help=f"{n_error_total} out of {num_total} total nodes are reconstruction error (feature=-1)"
)
with col2:
st.metric(
"% Node Influence (Error)",
f"{pct_error_influence:.1f}%",
help=f"Reconstruction error nodes contribute {pct_error_influence:.1f}% of total node_influence"
)
    # Info message with the breakdown
info_parts = [f"{n_embeddings} embeddings", f"{n_features} features"]
if n_error_nodes > 0:
info_parts.append(f"{n_error_nodes} error nodes")
excluded_parts = [f"{n_logits_excluded} logits"]
if n_error_excluded > 0:
excluded_parts.append(f"{n_error_excluded} error nodes")
st.info(f"📊 Displaying {n_embeddings + n_features + n_error_nodes} nodes: {', '.join(info_parts)} ({', '.join(excluded_parts)} excluded)")
    # Identify the 2 groups: embeddings and features (exclude logits)
is_embedding_group = scatter_df['layer'] == -1
max_layer = scatter_df['layer'].max()
is_logit_group = scatter_df['layer'] == max_layer
is_feature_group = ~(is_embedding_group | is_logit_group)
    # REMOVE LOGITS from the dataset
    scatter_df = scatter_df[~is_logit_group].copy()
    # Recompute the masks after the filter
    is_embedding_group = scatter_df['layer'] == -1
    is_feature_group = scatter_df['layer'] != -1
    # Add a column for the node type (only 2 types now)
scatter_df['node_type'] = 'feature'
scatter_df.loc[is_embedding_group, 'node_type'] = 'embedding'
    # Compute influence_log normalized per group, with a more aggressive formula
    # Each group has its own scale based on the group maximum
scatter_df['influence_log'] = 0.0
for group_name, group_mask in [('embedding', is_embedding_group),
('feature', is_feature_group)]:
if group_mask.sum() > 0:
group_data = scatter_df[group_mask]['node_influence'].abs()
            # Normalize with respect to the group maximum
max_in_group = group_data.max()
if max_in_group > 0:
normalized = group_data / max_in_group
                # More aggressive formula: use power 3 to exaggerate the differences
                # normalized^3 makes low values much smaller and high values relatively larger
                # Multiply by 1000 to get a usable size range
scatter_df.loc[group_mask, 'influence_log'] = (normalized ** 3) * 1000 + 10
else:
                scatter_df.loc[group_mask, 'influence_log'] = 10  # Default minimum value
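    # Worked example of the size formula (illustrative numbers): a node at 50% of the group max gets
    # (0.5 ** 3) * 1000 + 10 = 135, while the group max gets 1.0 * 1000 + 10 = 1010, so mid-sized
    # nodes shrink sharply relative to the top node of their group.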
    # Create the scatter plot with a different symbol per group (embeddings and features only)
symbol_map = {
'embedding': 'square',
'feature': 'circle'
}
fig = px.scatter(
scatter_df,
        x='ctx_idx_display',  # Use the offset position
        y='layer',
        size='influence_log',  # Aggressive (power 3) scale, normalized per group
        symbol='node_type',  # Different symbol per type
        symbol_map=symbol_map,
        color='node_type',  # Different color per type
        color_discrete_map={
            'embedding': '#4CAF50',  # Green for embeddings
            'feature': '#808080'  # Gray for features
},
labels={
'id': 'Node ID',
'ctx_idx_display': 'Context Position',
'ctx_idx': 'ctx_idx',
'layer': 'Layer',
'influence': 'Cumulative Influence',
'node_influence': 'Node Influence',
'node_type': 'Node Type',
'token': 'Token',
'feature': 'Feature'
},
title='Features by Layer and Position (size: node_influence^3 normalized per group)',
hover_data={
'ctx_idx': True,
'token': True,
'layer': True,
'node_type': True,
'id': True,
'feature': True,
            'node_influence': ':.6f',  # Marginal influence (symbol size)
            'influence': ':.4f',  # Cumulative influence (slider filter)
            'ctx_idx_display': False,  # Hide the offset position
            'influence_log': False  # Hide the size-scale value
}
)
    # Customize the layout: high transparency and a marked outline
    # Applied to all traces (embeddings and features)
max_influence_log = scatter_df['influence_log'].max()
fig.update_traces(
marker=dict(
sizemode='area',
sizeref=2.*max_influence_log/(50.**2) if max_influence_log > 0 else 1,
            sizemin=2,  # Minimum size
            opacity=0.3,  # Medium-high transparency
            line=dict(width=1.5, color='white')  # White outline to keep overlapping markers distinguishable
)
)
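    # Note on marker sizing: with sizemode='area', Plotly's recommended scaling is
    # sizeref = 2 * max(size_values) / (desired_max_diameter_px ** 2), so the 50. above
    # targets a maximum marker diameter of roughly 50 px.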
    # Build custom tick labels for the x axis (ctx_idx: token)
unique_ctx = sorted(scatter_df['ctx_idx'].unique())
tick_labels = [f"{ctx}: {token_map.get(ctx, '')}" for ctx in unique_ctx]
fig.update_layout(
template='plotly_white',
height=600,
        showlegend=True,  # Show the legend for the 2 groups
legend=dict(
title="Node Type",
orientation="v",
yanchor="top",
y=0.99,
xanchor="left",
x=0.99,
bgcolor="rgba(255,255,255,0.8)"
),
xaxis=dict(
gridcolor='lightgray',
tickmode='array',
tickvals=unique_ctx,
ticktext=tick_labels,
tickangle=-45
),
yaxis=dict(gridcolor='lightgray')
)
st.plotly_chart(fig, use_container_width=True)
    # Show per-group statistics
with st.expander("📊 Statistics by Group (Size Normalization)", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.markdown("**🟩 Embeddings (green squares)**")
emb_data = scatter_df[scatter_df['node_type'] == 'embedding']
if len(emb_data) > 0:
st.metric("Nodes", len(emb_data))
st.metric("Max node_influence", f"{emb_data['node_influence'].max():.6f}")
st.metric("Mean node_influence", f"{emb_data['node_influence'].mean():.6f}")
st.metric("Min node_influence", f"{emb_data['node_influence'].min():.6f}")
else:
st.info("No embeddings in filtered dataset")
with col2:
st.markdown("**⚪ Features (gray circles)**")
feat_data = scatter_df[scatter_df['node_type'] == 'feature']
if len(feat_data) > 0:
st.metric("Nodes", len(feat_data))
st.metric("Max node_influence", f"{feat_data['node_influence'].max():.6f}")
st.metric("Mean node_influence", f"{feat_data['node_influence'].mean():.6f}")
st.metric("Min node_influence", f"{feat_data['node_influence'].min():.6f}")
else:
st.info("No features in filtered dataset")
st.info("""
💡 **Size formula**: `size = (normalized_node_influence)³ × 1000 + 10`
Size is normalized **per group** and uses **power 3** to emphasize differences:
- A node with 50% of max → size = 0.5³ = 12.5% (much smaller)
- A node with 80% of max → size = 0.8³ = 51.2%
- A node with 100% of max → size = 1.0³ = 100%
The 2 groups (embeddings and features) have independent scales.
Note: in the JSON the "influence" field is the pre-pruning cumulative, so estimating node_influence as the difference between consecutive cumulatives is only a normalized proxy (to be renormalized on the current set), because the graph may already be topologically pruned and the selection does not coincide with a contiguous prefix of sorted nodes.
""")
    # === PARETO CHART: NODE INFLUENCE (features only, no embeddings/logits) ===
with st.expander("📈 Pareto Analysis Node Influence (Features only)", expanded=False):
try:
            # Keep only features (scatter_df has already had logits removed and has node_type)
features_only = scatter_df[scatter_df['node_type'] == 'feature'].copy()
if len(features_only) == 0:
st.warning("⚠️ No features found in filtered dataset")
return
            # Sort by descending node_influence
sorted_df = features_only.sort_values('node_influence', ascending=False).reset_index(drop=True)
            # Compute rank and percentile
sorted_df['rank'] = range(1, len(sorted_df) + 1)
sorted_df['rank_pct'] = sorted_df['rank'] / len(sorted_df) * 100
            # Compute the cumulative node_influence (running sum)
total_node_inf = sorted_df['node_influence'].sum()
if total_node_inf == 0:
st.warning("⚠️ Total Node influence is 0")
return
sorted_df['cumulative_node_influence'] = sorted_df['node_influence'].cumsum()
sorted_df['cumulative_node_influence_pct'] = sorted_df['cumulative_node_influence'] / total_node_inf * 100
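            # Illustrative example: node_influence values [0.4, 0.3, 0.2, 0.1] give a cumulative of
            # [0.4, 0.7, 0.9, 1.0] and cumulative_node_influence_pct [40, 70, 90, 100].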
            # Build the Pareto chart with a dual Y axis
            import plotly.graph_objects as go
            from plotly.subplots import make_subplots
            # Create a subplot with a secondary Y axis
            fig_pareto = make_subplots(specs=[[{"secondary_y": True}]])
            # Bars: individual node_influence (limited to the first 100 nodes for readability)
display_limit = min(100, len(sorted_df))
fig_pareto.add_trace(
go.Bar(
x=sorted_df['rank'][:display_limit],
y=sorted_df['node_influence'][:display_limit],
name='Node Influence',
marker=dict(color='#2196F3', opacity=0.6),
                    hovertemplate='Rank: %{x}<br>Node Influence: %{y:.6f}'
),
secondary_y=False
)
            # Line: cumulative % (uses all nodes)
fig_pareto.add_trace(
go.Scatter(
x=sorted_df['rank_pct'],
y=sorted_df['cumulative_node_influence_pct'],
mode='lines+markers',
name='Cumulative %',
line=dict(color='#FF5722', width=3),
marker=dict(size=4),
                    hovertemplate='Top %{x:.1f}% features<br>Cumulative: %{y:.1f}%'
),
secondary_y=True
)
            # Pareto reference lines (80%, 90%, 95%)
for pct, label in [(80, '80%'), (90, '90%'), (95, '95%')]:
fig_pareto.add_hline(
y=pct,
line_dash="dash",
line_color="gray",
opacity=0.5,
secondary_y=True
)
fig_pareto.add_annotation(
x=100,
y=pct,
text=label,
showarrow=False,
xanchor='left',
yref='y2'
)
            # Find the "knee" (the point where the cumulative reaches 80%)
knee_idx = (sorted_df['cumulative_node_influence_pct'] >= 80).idxmax()
knee_rank_pct = sorted_df.loc[knee_idx, 'rank_pct']
knee_cumul = sorted_df.loc[knee_idx, 'cumulative_node_influence_pct']
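            # Note: idxmax() returns the first row whose cumulative % reaches 80; if no row ever does
            # (all False), it falls back to the first row, so the knee is only meaningful when the
            # cumulative actually crosses 80%.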
fig_pareto.add_trace(
go.Scatter(
x=[knee_rank_pct],
y=[knee_cumul],
mode='markers',
name='Knee (80%)',
marker=dict(size=15, color='#4CAF50', symbol='diamond', line=dict(width=2, color='white')),
                    hovertemplate=f'Knee Point<br>Top {knee_rank_pct:.1f}% features<br>Cumulative: {knee_cumul:.1f}%',
showlegend=True
),
secondary_y=True
)
# Layout
fig_pareto.update_xaxes(title_text="Rank % Features (by descending node_influence)")
fig_pareto.update_yaxes(title_text="Node Influence (individual)", secondary_y=False)
fig_pareto.update_yaxes(title_text="Cumulative % Node Influence", secondary_y=True, range=[0, 105])
fig_pareto.update_layout(
height=500,
showlegend=True,
template='plotly_white',
legend=dict(x=0.02, y=0.98, xanchor='left', yanchor='top'),
title="Pareto Chart: Node Influence of Features"
)
st.plotly_chart(fig_pareto, use_container_width=True)
            # Key Pareto statistics
st.markdown("#### 📊 Pareto Statistics (Node Influence)")
col1, col2, col3, col4 = st.columns(4)
            # Find key percentiles
top_10_idx = max(0, int(len(sorted_df) * 0.1))
top_20_idx = max(0, int(len(sorted_df) * 0.2))
top_50_idx = max(0, int(len(sorted_df) * 0.5))
top_10_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_10_idx] if top_10_idx < len(sorted_df) else 0
top_20_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_20_idx] if top_20_idx < len(sorted_df) else 0
top_50_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_50_idx] if top_50_idx < len(sorted_df) else 0
with col1:
st.metric("Top 10% features", f"{top_10_pct:.1f}% node_influence",
help=f"The top {int(len(sorted_df)*0.1)} most influential features cover {top_10_pct:.1f}% of total influence")
with col2:
st.metric("Top 20% features", f"{top_20_pct:.1f}% node_influence",
help=f"The top {int(len(sorted_df)*0.2)} most influential features cover {top_20_pct:.1f}% of total influence")
with col3:
st.metric("Top 50% features", f"{top_50_pct:.1f}% node_influence",
help=f"The top {int(len(sorted_df)*0.5)} most influential features cover {top_50_pct:.1f}% of total influence")
with col4:
                # Gini coefficient: features are sorted by descending influence, so the cumulative-share
                # curve lies above the diagonal and Gini = 2 * (area under the curve) - 1 (trapezoidal estimate)
                gini = 2 * np.trapz(sorted_df['cumulative_node_influence_pct'] / 100, sorted_df['rank_pct'] / 100) - 1
st.metric("Gini Coefficient", f"{gini:.3f}", help="0 = equal distribution, 1 = highly concentrated")
            # Knee-point info and threshold suggestion
            # sorted_df.loc[knee_idx] gives the knee-point row
knee_cumul_threshold = sorted_df.loc[knee_idx, 'influence'] if 'influence' in sorted_df.columns else scatter_df['influence'].max()
st.success(f"""
🎯 **Knee Point (80%)**: The first **{knee_rank_pct:.1f}%** of features ({int(len(sorted_df) * knee_rank_pct / 100)} nodes)
cover **80%** of total node_influence.
💡 **Threshold Suggestion**: To focus on features up to the knee point (80%),
use `cumulative_threshold ≈ {knee_cumul_threshold:.4f}` in the slider above.
""")
            # Histogram of the node_influence distribution (optional, in an expander)
with st.expander("📊 Node Influence Distribution Histogram", expanded=False):
fig_hist = px.histogram(
sorted_df,
x='node_influence',
nbins=50,
title='Node Influence Distribution (Features)',
labels={'node_influence': 'Node Influence', 'count': 'Frequency'},
color_discrete_sequence=['#2196F3']
)
fig_hist.update_layout(
height=350,
template='plotly_white',
showlegend=False
)
fig_hist.update_traces(marker=dict(opacity=0.7))
st.plotly_chart(fig_hist, use_container_width=True)
                # Distribution statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Mean", f"{sorted_df['node_influence'].mean():.6f}")
with col2:
st.metric("Median", f"{sorted_df['node_influence'].median():.6f}")
with col3:
st.metric("Std Dev", f"{sorted_df['node_influence'].std():.6f}")
with col4:
st.metric("Max", f"{sorted_df['node_influence'].max():.6f}")
except Exception as e:
st.error(f"❌ Error creating distribution chart: {str(e)}")
import traceback
st.code(traceback.format_exc())
    # Return the filtered features (SAE features only: no embeddings/logits/error nodes)
    # Useful for export
sae_features_only = scatter_filtered[
~(is_embedding_filtered | is_logit_filtered | is_error_filtered)
].copy()
return sae_features_only