import os

import gradio as gr
import torch
from PIL import Image

# Project modules: hooked SDXL pipeline, k-sparse autoencoder, and text-prompt hooks
from SDLens import HookedStableDiffusionXLPipeline
from training.k_sparse_autoencoder import SparseAutoencoder
from utils.hooks import add_feature_on_text_prompt

# Build a hook that injects the steering feature into a block's text embeddings
def modulate_hook_prompt(sae, steering_feature, block):
    # `block` is accepted so one factory serves every hooked block; the injection
    # itself only needs the SAE and the precomputed steering feature.
    def hook_function(*args, **kwargs):
        return add_feature_on_text_prompt(
            sae,
            steering_feature,
            *args, **kwargs
        )
    return hook_function
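
# NOTE (assumption): pipe.run_with_hooks is expected to call each registered hook
# with the hooked module's args/kwargs and to use the returned value as that
# module's new output; the exact contract lives in SDLens and utils.hooks.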

# Load the hooked SDXL-Turbo pipeline and the sparse autoencoder checkpoint
def load_models():
    try:
        # Load the hooked pipeline
        pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
        pipe.set_progress_bar_config(disable=True)

        # Text-encoder layers whose outputs are cached and hooked
        blocks_to_save = [
            'text_encoder.text_model.encoder.layers.10',
            'text_encoder_2.text_model.encoder.layers.28',
        ]

        # Load the sparse autoencoder from its final checkpoint
        sae_path = "Checkpoints/dahyecheckpoint"
        sae = SparseAutoencoder.load_from_disk(os.path.join(sae_path, 'final'))
        return pipe, blocks_to_save, sae
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None, None
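
# NOTE (assumption): from_pretrained above loads the pipeline on CPU in float32.
# On a CUDA machine you would typically pass torch_dtype=torch.float16 and call
# pipe.to("cuda"), keeping the SAE on the same device and dtype.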

# Generate an image while modulating text-encoder activations with the SAE
def activation_modulation_across_prompt(pipe, sae, blocks_to_save, steer_prompt, strength, prompt, guidance_scale, num_inference_steps, seed):
    # Run the steering prompt once and cache the text-encoder outputs
    output, cache = pipe.run_with_cache(
        steer_prompt,
        positions_to_cache=blocks_to_save,
        save_input=True,
        save_output=True,
        num_inference_steps=1,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed)
    )

    # Concatenate the cached outputs of both text-encoder blocks along the feature dim
    diff = torch.cat([cache['output'][blocks_to_save[0]], cache['output'][blocks_to_save[1]]], dim=-1)
    diff = diff.squeeze(0).squeeze(0)
    with torch.no_grad():
        activated = sae.encode_without_topk(diff)  # per-token SAE activations, [77, 81920]
        mask = activated * strength                # scale by the modulation strength
        to_add = mask @ sae.decoder.weight.T       # decode back to embedding space
    steering_feature = to_add
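
    # In symbols: steering_feature = strength * a(x) @ W_dec^T, where a(x) are the
    # SAE's pre-top-k activations of the steering prompt's concatenated embeddings,
    # so every feature the steering prompt activates is modulated at once.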

    # Generate the final image with the steering hooks attached to both blocks
    output = pipe.run_with_hooks(
        prompt,
        position_hook_dict={
            block: modulate_hook_prompt(sae, steering_feature, block)
            for block in blocks_to_save
        },
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed)
    )
    return output.images[0]
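
# Example call (hypothetical values taken from the UI defaults), assuming the
# models below loaded successfully:
#   image = activation_modulation_across_prompt(
#       pipe, sae, blocks_to_save,
#       steer_prompt="tree with autumn leaves", strength=0.8,
#       prompt="A photo of a tree", guidance_scale=0.0,
#       num_inference_steps=3, seed=61730)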

# Build the side-by-side comparison shown in the Gradio app
def generate_comparison(prompt, steer_prompt, strength, seed, guidance_scale, steps):
    if pipe is None or sae is None or blocks_to_save is None:
        error_image = Image.new('RGB', (512, 512), color='red')
        return error_image, error_image, "Error: Models failed to load"

    try:
        # Baseline image from the unmodified model
        standard_image = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed)
        ).images[0]

        # Generate the modulated image; the slider allows negative strengths,
        # which steer away from the concept, so only skip the exact-zero case
        if strength != 0:
            modified_image = activation_modulation_across_prompt(
                pipe, sae, blocks_to_save,
                steer_prompt, strength, prompt,
                guidance_scale, steps, seed
            )
        else:
            # Strength 0 is a no-op, so reuse the standard image to avoid redundant computation
            modified_image = standard_image

        comparison_message = f"Generated images with modulation strength: {strength}"
        return standard_image, modified_image, comparison_message
    except Exception as e:
        error_image = Image.new('RGB', (512, 512), color='red')
        return error_image, error_image, f"Error during generation: {str(e)}"

# Load the models once at startup
print("Loading models...")
pipe, blocks_to_save, sae = load_models()
if pipe is not None:
    print("Models loaded successfully!")
else:
    print("Failed to load models")

# Define the Gradio interface
with gr.Blocks(title="SDXL Activation Modulation") as app:
    gr.Markdown("# SDXL Activation Modulation Comparison")
    gr.Markdown("""
    This app demonstrates activation modulation in Stable Diffusion XL using sparse autoencoders.
    It compares standard SDXL-Turbo output with output steered toward a separate concept given as a steering prompt.
    """)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your main image prompt here...", value="A photo of a tree")
            steer_prompt = gr.Textbox(label="Steering Prompt", placeholder="Enter concept to steer with...", value="tree with autumn leaves")
            strength = gr.Slider(minimum=-2.5, maximum=2.5, value=0.8, step=0.05,
                                 label="Modulation Strength (λ)")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=61730, label="Seed")
                guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.5, label="Guidance Scale")
                steps = gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Inference Steps")

            generate_btn = gr.Button("Generate Comparison", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        standard_output = gr.Image(label="Standard SDXL-Turbo")
        modified_output = gr.Image(label="Modulated Output")
gr.Markdown("""
## Examples from the notebook:
- Main prompt: "A photo of a tree" with steering prompt: "tree with autumn leaves"
- Main prompt: "A dog" with steering prompt: "full shot"
- Main prompt: "A car" with steering prompt: "A blue car"
""")
with gr.Row():
example1 = gr.Button("Example 1: Tree with autumn leaves")
example2 = gr.Button("Example 2: Dog with full shot")
example3 = gr.Button("Example 3: Blue car")

    # Set up button actions
    generate_btn.click(
        fn=generate_comparison,
        inputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps],
        outputs=[standard_output, modified_output, status]
    )

    # Set up example button click events: each prefills the inputs with preset values
    example1.click(
        fn=lambda: ["A photo of a tree", "tree with autumn leaves", 0.5, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )
    example2.click(
        fn=lambda: ["A dog", "full shot", 0.4, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )
    example3.click(
        fn=lambda: ["A car", "A blue car", 0.3, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )
gr.Markdown("""
## How to Use
1. Enter your main prompt (what you want to generate)
2. Enter a steering prompt (concept to influence the generation)
3. Adjust the modulation strength slider (λ) - higher values mean stronger influence
4. Click "Generate Comparison" to see the results side by side
5. Use advanced settings if needed to adjust seed, guidance scale, or steps
## About
This app demonstrates activation modulation using a sparse autoencoder trained on SDXL text encoder layers.
The modulation allows steering the generation toward specific concepts without changing the main prompt.
""")
# Launch the app
if __name__ == "__main__":
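    # Hypothetical variants: app.launch(share=True) creates a public Gradio link;
    # app.launch(server_port=7860) pins the local port.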
    app.launch()