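"""Gradio demo for the Matryoshka Sparse Autoencoder (MSAE) on CLIP features.

Upload an image to see its top activating SAE concepts, then pick a concept
(from `clip_disect_20k.txt`) to visualize how changing the magnitude of its
corresponding SAE feature affects the image representation.
Paper: https://openreview.net/forum?id=5MQQsenQBm  Code: https://github.com/WolodjaZ/MSAE
"""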
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import clip
from sae import SAE
import os
# --- 1. Setup and Model Loading ---
# Use GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Define file paths for clarity
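# (Assumed from usage below: the SAE checkpoint, a precomputed concept-to-feature
# similarity matrix, and the CLIP-Dissect 20k concept vocabulary, respectively.)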
SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth'
VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy'
VOCAB_NAMES_PATH = 'clip_disect_20k.txt'
# Load models and data
try:
# Load CLIP model
model, preprocess = clip.load("ViT-L/14", device=device)
# Load Sparse Autoencoder (SAE) model
# Ensure the SAE class correctly handles moving the model to the specified device
sae_model = SAE(SAE_MODEL_PATH).to(device).eval()
# Load pre-computed vocabulary scores and names
vocab_scores = np.load(VOCAB_SCORES_PATH)
with open(VOCAB_NAMES_PATH, 'r') as f:
vocab_names = [line.strip().lower() for line in f.readlines()]
except FileNotFoundError as e:
print(f"ERROR: A required file was not found: {e.filename}")
print("Please ensure all model and vocabulary files are present in the correct paths.")
    # Abort startup if essential files are missing
    raise SystemExit(1)
# Pre-calculate mappings for faster lookup
# For a given feature index, what is the best concept name?
feature_to_concept_score = np.max(vocab_scores, axis=0)
feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0)
# For a given concept name, what is the best feature index?
concept_to_feature_score = np.max(vocab_scores, axis=1)
concept_to_feature_idx = np.argmax(vocab_scores, axis=1)
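# Note: vocab_scores is assumed to have shape (num_concepts, num_features), i.e. rows
# index vocabulary words and columns index SAE features, so axis=0 reduces over concepts
# and axis=1 reduces over features.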
# --- 2. Helper and Core Logic Functions ---
def calculate_fvu(original_input, reconstruction):
"""Calculates the Fraction of Variance Unexplained (FVU)."""
variance = (original_input - original_input.mean(dim=-1, keepdim=True)).var(dim=-1)
recon_error_variance = (original_input - reconstruction).var(dim=-1)
    # Add a small epsilon to the denominator to avoid division by zero
fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
return fvu_val.item()
def predict(input_img, top_k, concept, neg_concept, max_strength, add_error):
"""
Main function to process an image, identify top concepts, and visualize concept manipulation.
"""
    if input_img is None:
raise gr.Error("Please provide an input image.")
# --- Part A: Top Concepts Analysis ---
# Preprocess the input image and move to the correct device
image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
# Encode the image with CLIP
image_features = model.encode_image(image_input_processed).to(torch.float32)
# Get SAE reconstruction and latent activations
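        # (The SAE forward pass is assumed to return (reconstruction, aux output, latents);
        # the middle value is unused here.)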
reconstructed_features, _, full_latents = sae_model(image_features)
fvu_score = calculate_fvu(image_features, reconstructed_features)
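        # Reconstruction residual, optionally added back to the decoded features below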
error = image_features - reconstructed_features
# Get the top K activating SAE features for the image
full_latents = full_latents.cpu().flatten()
top_sae_values, top_sae_indices = full_latents.topk(k=top_k)
# Create the bar plot for top concepts
fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
concept_labels = [
f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})"
for i in top_sae_indices
]
ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue')
ax_bar.set_yticks(range(top_k))
ax_bar.set_yticklabels(concept_labels)
ax_bar.invert_yaxis() # Display top concept at the top
ax_bar.set_xlabel("SAE Feature Activation")
ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}")
plt.tight_layout()
# --- Part B: Concept Manipulation ---
# Validate the user-provided concept
concept = concept.lower().strip()
if concept not in vocab_names:
raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.")
    # Look up the feature index and assignment score for the chosen concept
    concept_vocab_idx = vocab_names.index(concept)
    concept_feature_id = concept_to_feature_idx[concept_vocab_idx]
    concept_assign_score = concept_to_feature_score[concept_vocab_idx]
# Get the original activation strength of this concept in the input image
original_strength = full_latents[concept_feature_id].item()
# Create positive and negative text prompts
if not neg_concept:
neg_concept_prompt = f"a photo without {concept}"
else:
neg_concept_prompt = f"a photo with {neg_concept.lower().strip()}"
pos_concept_prompt = f"a photo with {concept}"
# Tokenize prompts and encode with CLIP
text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device)
with torch.no_grad():
text_features = model.encode_text(text_labels)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Define the range of strengths to test
strengths = torch.linspace(0.0, max_strength, 11).to(device)
pos_concept_probs, neg_concept_probs, cos_sims = [], [], []
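    # Normalized original reconstruction, used as the reference for cosine similarity below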
original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True)
for st in strengths:
with torch.no_grad():
# Create a copy of latents and modify the target concept feature
modified_latents = full_latents.clone().to(device).reshape(1, -1)
modified_latents[:, concept_feature_id] = st
# Decode the modified latents back into feature space
modified_reconstructed = sae_model.decode(modified_latents)
if add_error:
modified_reconstructed = modified_reconstructed + error
# Normalize for comparison
modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
# Calculate similarity to the text prompts (probabilities)
probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1)
pos_concept_probs.append(probs[0, 0].item())
neg_concept_probs.append(probs[0, 1].item())
# Calculate cosine similarity to the original reconstructed image
cos_sims.append(
torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item()
)
# Create the line plot for concept manipulation
fig_line, ax_line = plt.subplots(figsize=(10, 6))
strengths_cpu = strengths.cpu().numpy()
ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"')
ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"')
    # Add a vertical line indicating the original strength of the concept in the image
ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})')
# Add cosine similarity on a secondary y-axis
ax2 = ax_line.twinx()
ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original')
ax2.set_ylabel('Cosine Similarity', color='green')
ax2.tick_params(axis='y', labelcolor='green')
ax_line.set_xlabel("Magnitude of the Concept SAE Feature")
ax_line.set_ylabel("CLIP Probability")
ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})")
fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes)
plt.tight_layout()
# Close figures to free memory
plt.close(fig_bar)
plt.close(fig_line)
return input_img, fig_bar, fig_line
# --- 3. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo:
    gr.Markdown(
        "Based on the paper [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm) ([GitHub code](https://github.com/WolodjaZ/MSAE)). "
        "Upload an image to see its top activating concepts from a sparse autoencoder, then choose a concept (from `clip_disect_20k.txt`) to visualize how manipulating the magnitude of its corresponding SAE feature affects the image representation."
    )
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil")
gr.Markdown("### Analysis & Manipulation Controls")
            top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Number of Top-K Concepts to Visualize")
concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
add_error_checkbox = gr.Checkbox(label="Add error term to reconstruction")
submit_btn = gr.Button("Analyze and Interpret", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### Results")
output_image = gr.Image(label="Original Image", interactive=False)
output_bar_plot = gr.Plot()
output_line_plot = gr.Plot()
gr.Examples(
examples=[
["bird.jpg", 10, "birds", "", 10.0, True],
["statue.jpg", 10, "statue", "humans", 10.0, True],
],
inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
outputs=[output_image, output_bar_plot, output_line_plot],
fn=predict,
cache_examples=True
)
# Wire up the button to the function
    submit_btn.click(
        fn=predict,
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
        outputs=[output_image, output_bar_plot, output_line_plot]
    )
if __name__ == "__main__":
demo.launch(debug=True)