import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import clip
from sae import SAE
import os

# --- 1. Setup and Model Loading ---
# Use GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define file paths for clarity
SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth'
VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy'
VOCAB_NAMES_PATH = 'clip_disect_20k.txt'
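# NOTE: the checkpoint, score matrix, and vocabulary file are expected in the
# working directory. The SAE filename appears to encode its configuration
# (likely 6144 latent features over 768-dim ViT-L/14 embeddings with a TopK
# ReLU, k=64), but the loading code below only needs the paths to resolve.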
# Load models and data
try:
    # Load CLIP model
    model, preprocess = clip.load("ViT-L/14", device=device)
    # Load Sparse Autoencoder (SAE) model
    # Ensure the SAE class correctly handles moving the model to the specified device
    sae_model = SAE(SAE_MODEL_PATH).to(device).eval()
    # Load pre-computed vocabulary scores and names
    vocab_scores = np.load(VOCAB_SCORES_PATH)
    with open(VOCAB_NAMES_PATH, 'r') as f:
        vocab_names = [line.strip().lower() for line in f.readlines()]
except FileNotFoundError as e:
    print(f"ERROR: A required file was not found: {e.filename}")
    print("Please ensure all model and vocabulary files are present in the correct paths.")
    # Exit if essential files are missing
    exit()
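# The rest of this script assumes the SAE class from sae.py exposes:
#   sae_model(features)        -> (reconstruction, <aux>, latent_activations)
#   sae_model.decode(latents)  -> reconstruction in CLIP embedding space
# These signatures are inferred from how sae_model is called below, not from
# separate documentation.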
# Pre-calculate mappings for faster lookup
# For a given feature index, what is the best concept name?
feature_to_concept_score = np.max(vocab_scores, axis=0)
feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0)
# For a given concept name, what is the best feature index?
concept_to_feature_score = np.max(vocab_scores, axis=1)
concept_to_feature_idx = np.argmax(vocab_scores, axis=1)
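# vocab_scores is treated as a [num_vocab_words, num_sae_features] matrix:
# reducing over axis 0 yields one entry per SAE feature (its best-matching word),
# while reducing over axis 1 yields one entry per word (its best-matching feature).
# A lightweight sanity check on that assumed layout:
assert vocab_scores.shape[0] == len(vocab_names), (
    "Expected one row in vocab_scores per line of the vocabulary file"
)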
# --- 2. Helper and Core Logic Functions ---
def calculate_fvu(original_input, reconstruction):
    """Calculates the Fraction of Variance Unexplained (FVU)."""
    variance = (original_input - original_input.mean(dim=-1, keepdim=True)).var(dim=-1)
    recon_error_variance = (original_input - reconstruction).var(dim=-1)
    # Add a small epsilon to the denominator to avoid division by zero
    fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
    return fvu_val.item()
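# FVU = Var(x - x_hat) / Var(x): 0.0 means a perfect reconstruction, while a
# value near 1.0 means the SAE explains no more variance than predicting the
# per-embedding mean. The score is reported in the bar-plot title below.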
def predict(input_img, top_k, concept, neg_concept, max_strength, add_error):
    """
    Main function to process an image, identify top concepts, and visualize concept manipulation.
    """
    if not input_img:
        raise gr.Error("Please provide an input image.")

    # --- Part A: Top Concepts Analysis ---
    # Preprocess the input image and move to the correct device
    image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Encode the image with CLIP
        image_features = model.encode_image(image_input_processed).to(torch.float32)
        # Get SAE reconstruction and latent activations
        reconstructed_features, _, full_latents = sae_model(image_features)
        fvu_score = calculate_fvu(image_features, reconstructed_features)
        error = image_features - reconstructed_features

    # Get the top K activating SAE features for the image
    full_latents = full_latents.cpu().flatten()
    top_sae_values, top_sae_indices = full_latents.topk(k=top_k)
    # Create the bar plot for top concepts
    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    concept_labels = [
        f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})"
        for i in top_sae_indices
    ]
    ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue')
    ax_bar.set_yticks(range(top_k))
    ax_bar.set_yticklabels(concept_labels)
    ax_bar.invert_yaxis()  # Display top concept at the top
    ax_bar.set_xlabel("SAE Feature Activation")
    ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}")
    plt.tight_layout()
    # --- Part B: Concept Manipulation ---
    # Validate the user-provided concept
    concept = concept.lower().strip()
    if concept not in vocab_names:
        raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.")

    # Get the feature index corresponding to the chosen concept
    concept_feature_id = concept_to_feature_idx[vocab_names.index(concept)]
    concept_assign_score = concept_to_feature_score[vocab_names.index(concept)]
    # Get the original activation strength of this concept in the input image
    original_strength = full_latents[concept_feature_id].item()

    # Create positive and negative text prompts
    if not neg_concept:
        neg_concept_prompt = f"a photo without {concept}"
    else:
        neg_concept_prompt = f"a photo with {neg_concept.lower().strip()}"
    pos_concept_prompt = f"a photo with {concept}"
    # Tokenize prompts and encode with CLIP
    text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device)
    with torch.no_grad():
        # Cast to float32 so the dtype matches the SAE reconstruction
        # (CLIP returns float16 features on GPU)
        text_features = model.encode_text(text_labels).to(torch.float32)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Define the range of strengths to test
    strengths = torch.linspace(0.0, max_strength, 11).to(device)
    pos_concept_probs, neg_concept_probs, cos_sims = [], [], []
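    # The loop below sweeps the chosen SAE feature from 0 to max_strength in 11
    # steps. At each step it decodes the edited latents back into CLIP space and
    # records (a) the softmax probability of the positive vs. negative prompt and
    # (b) the cosine similarity to the unedited reconstruction, i.e. how far the
    # intervention has moved the embedding.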
    original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True)

    for st in strengths:
        with torch.no_grad():
            # Create a copy of latents and modify the target concept feature
            modified_latents = full_latents.clone().to(device).reshape(1, -1)
            modified_latents[:, concept_feature_id] = st
            # Decode the modified latents back into feature space
            modified_reconstructed = sae_model.decode(modified_latents)
            if add_error:
                modified_reconstructed = modified_reconstructed + error
            # Normalize for comparison
            modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
            # Calculate similarity to the text prompts (probabilities)
            probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1)
            pos_concept_probs.append(probs[0, 0].item())
            neg_concept_probs.append(probs[0, 1].item())
            # Calculate cosine similarity to the original reconstructed image
            cos_sims.append(
                torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item()
            )
    # Create the line plot for concept manipulation
    fig_line, ax_line = plt.subplots(figsize=(10, 6))
    strengths_cpu = strengths.cpu().numpy()
    ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"')
    ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"')
    # Add a vertical line indicating the original strength of the concept
    ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})')

    # Add cosine similarity on a secondary y-axis
    ax2 = ax_line.twinx()
    ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original')
    ax2.set_ylabel('Cosine Similarity', color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    ax_line.set_xlabel("Magnitude of the Concept SAE Feature")
    ax_line.set_ylabel("CLIP Probability")
    ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})")
    fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes)
    plt.tight_layout()

    # Close figures to free memory
    plt.close(fig_bar)
    plt.close(fig_line)

    return input_img, fig_bar, fig_line
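# Example of calling predict() directly, without the Gradio UI (hypothetical
# local file; mirrors the first entry in gr.Examples below):
#   from PIL import Image
#   img, bar_fig, line_fig = predict(
#       Image.open("bird.jpg"), top_k=10, concept="birds",
#       neg_concept="", max_strength=10.0, add_error=True,
#   )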
# --- 3. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo:
    gr.Markdown(
        "Based on the paper [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm) ([GitHub code](https://github.com/WolodjaZ/MSAE)). "
        "Upload an image to see its top activating concepts from a sparse autoencoder. Then choose a concept (from `clip_disect_20k.txt`) to visualize how changing the magnitude of its corresponding SAE feature affects the image representation."
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil")
            gr.Markdown("### Analysis & Manipulation Controls")
            top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Number of Top-K Concepts to Visualize")
            concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
            neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
            max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
            add_error_checkbox = gr.Checkbox(label="Add error term to reconstruction")
            submit_btn = gr.Button("Analyze and Interpret", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            output_image = gr.Image(label="Original Image", interactive=False)
            output_bar_plot = gr.Plot()
            output_line_plot = gr.Plot()
    gr.Examples(
        examples=[
            ["bird.jpg", 10, "birds", "", 10.0, True],
            ["statue.jpg", 10, "statue", "humans", 10.0, True],
        ],
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
        outputs=[output_image, output_bar_plot, output_line_plot],
        fn=predict,
        cache_examples=True
    )
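    # cache_examples=True runs predict() on both examples at startup, so
    # bird.jpg, statue.jpg, and all model files must be available at that point.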
    # Wire up the button to the function (note: add_error_checkbox must be
    # included so predict() receives all six arguments)
    submit_btn.click(
        fn=predict,
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
        outputs=[output_image, output_bar_plot, output_line_plot]
    )
if __name__ == "__main__":
    demo.launch(debug=True)