import torch import numpy as np import gradio as gr import matplotlib.pyplot as plt import clip from sae import SAE import os # --- 1. Setup and Model Loading --- # Use GPU if available, otherwise CPU device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Define file paths for clarity SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth' VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy' VOCAB_NAMES_PATH = 'clip_disect_20k.txt' # Load models and data try: # Load CLIP model model, preprocess = clip.load("ViT-L/14", device=device) # Load Sparse Autoencoder (SAE) model # Ensure the SAE class correctly handles moving the model to the specified device sae_model = SAE(SAE_MODEL_PATH).to(device).eval() # Load pre-computed vocabulary scores and names vocab_scores = np.load(VOCAB_SCORES_PATH) with open(VOCAB_NAMES_PATH, 'r') as f: vocab_names = [line.strip().lower() for line in f.readlines()] except FileNotFoundError as e: print(f"ERROR: A required file was not found: {e.filename}") print("Please ensure all model and vocabulary files are present in the correct paths.") # Exit if essential files are missing exit() # Pre-calculate mappings for faster lookup # For a given feature index, what is the best concept name? feature_to_concept_score = np.max(vocab_scores, axis=0) feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0) # For a given concept name, what is the best feature index? concept_to_feature_score = np.max(vocab_scores, axis=1) concept_to_feature_idx = np.argmax(vocab_scores, axis=1) # --- 2. Helper and Core Logic Functions --- def calculate_fvu(original_input, reconstruction): """Calculates the Fraction of Variance Unexplained (FVU).""" variance = (original_input - original_input.mean(dim=-1, keepdim=True)).var(dim=-1) recon_error_variance = (original_input - reconstruction).var(dim=-1) # Clamp to avoid division by zero or tiny numbers fvu_val = (recon_error_variance / (variance + 1e-8)).mean() return fvu_val.item() def predict(input_img, top_k, concept, neg_concept, max_strength, add_error): """ Main function to process an image, identify top concepts, and visualize concept manipulation. """ if not input_img: raise gr.Error("Please provide an input image.") # --- Part A: Top Concepts Analysis --- # Preprocess the input image and move to the correct device image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device) with torch.no_grad(): # Encode the image with CLIP image_features = model.encode_image(image_input_processed).to(torch.float32) # Get SAE reconstruction and latent activations reconstructed_features, _, full_latents = sae_model(image_features) fvu_score = calculate_fvu(image_features, reconstructed_features) error = image_features - reconstructed_features # Get the top K activating SAE features for the image full_latents = full_latents.cpu().flatten() top_sae_values, top_sae_indices = full_latents.topk(k=top_k) # Create the bar plot for top concepts fig_bar, ax_bar = plt.subplots(figsize=(10, 6)) concept_labels = [ f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})" for i in top_sae_indices ] ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue') ax_bar.set_yticks(range(top_k)) ax_bar.set_yticklabels(concept_labels) ax_bar.invert_yaxis() # Display top concept at the top ax_bar.set_xlabel("SAE Feature Activation") ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}") plt.tight_layout() # --- Part B: Concept Manipulation --- # Validate the user-provided concept concept = concept.lower().strip() if concept not in vocab_names: raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.") # Get the feature index corresponding to the chosen concept concept_feature_id = concept_to_feature_idx[vocab_names.index(concept)] concept_assign_score = concept_to_feature_score[vocab_names.index(concept)] # Get the original activation strength of this concept in the input image original_strength = full_latents[concept_feature_id].item() # Create positive and negative text prompts if not neg_concept: neg_concept_prompt = f"a photo without {concept}" else: neg_concept_prompt = f"a photo with {neg_concept.lower().strip()}" pos_concept_prompt = f"a photo with {concept}" # Tokenize prompts and encode with CLIP text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device) with torch.no_grad(): text_features = model.encode_text(text_labels) text_features /= text_features.norm(dim=-1, keepdim=True) # Define the range of strengths to test strengths = torch.linspace(0.0, max_strength, 11).to(device) pos_concept_probs, neg_concept_probs, cos_sims = [], [], [] original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True) for st in strengths: with torch.no_grad(): # Create a copy of latents and modify the target concept feature modified_latents = full_latents.clone().to(device).reshape(1, -1) modified_latents[:, concept_feature_id] = st # Decode the modified latents back into feature space modified_reconstructed = sae_model.decode(modified_latents) if add_error: modified_reconstructed = modified_reconstructed + error # Normalize for comparison modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True) # Calculate similarity to the text prompts (probabilities) probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1) pos_concept_probs.append(probs[0, 0].item()) neg_concept_probs.append(probs[0, 1].item()) # Calculate cosine similarity to the original reconstructed image cos_sims.append( torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item() ) # Create the line plot for concept manipulation fig_line, ax_line = plt.subplots(figsize=(10, 6)) strengths_cpu = strengths.cpu().numpy() ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"') ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"') # ✨ NEW: Add a vertical line indicating the original strength of the concept ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})') # Add cosine similarity on a secondary y-axis ax2 = ax_line.twinx() ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original') ax2.set_ylabel('Cosine Similarity', color='green') ax2.tick_params(axis='y', labelcolor='green') ax_line.set_xlabel("Magnitude of the Concept SAE Feature") ax_line.set_ylabel("CLIP Probability") ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})") fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes) plt.tight_layout() # Close figures to free memory plt.close(fig_bar) plt.close(fig_line) return input_img, fig_bar, fig_line # --- 3. Gradio Interface --- with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo: gr.Markdown( "Based on the paper: [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm) with [github code](https://github.com/WolodjaZ/MSAE). " "Upload an image to see its top activating concepts from a sparse autoencoder. Then, choose a concept (from `clip_disect_20k.txt`) to visualize how manipulating its corresponding concept magnitude affects the image representation." ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil") gr.Markdown("### Analysis & Manipulation Controls") top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Numb of Top K Concepts to visualize") concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair") neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown") max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength") add_error_checkbox = gr.Checkbox(label="Add error term to reconstruction") submit_btn = gr.Button("Analyze and Interpret", variant="primary") with gr.Column(scale=2): gr.Markdown("### Results") output_image = gr.Image(label="Original Image", interactive=False) output_bar_plot = gr.Plot() output_line_plot = gr.Plot() gr.Examples( examples=[ ["bird.jpg", 10, "birds", "", 10.0, True], ["statue.jpg", 10, "statue", "humans", 10.0, True], ], inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox], outputs=[output_image, output_bar_plot, output_line_plot], fn=predict, cache_examples=True ) # Wire up the button to the function submit_btn.click( fn=predict, inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider], outputs=[output_image, output_bar_plot, output_line_plot] ) if __name__ == "__main__": demo.launch(debug=True)