|
import gradio as gr
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
|
|
from SDLens import HookedStableDiffusionXLPipeline
|
|
from training.k_sparse_autoencoder import SparseAutoencoder
|
|
from utils.hooks import add_feature_on_text_prompt
|
|
|
|
|
|
def modulate_hook_prompt(sae, steering_feature, block):
    """Return a hook callable that forwards to ``add_feature_on_text_prompt``.

    The returned function passes ``sae`` and ``steering_feature`` as the first
    two positional arguments. It appends whatever positional arguments the hook
    machinery supplies and passes keyword arguments through unchanged. It
    returns the helper's result.

    Args:
        sae: Sparse autoencoder forwarded to the helper.
        steering_feature: Steering vector forwarded to the helper.
        block: Name of the hooked block. It is not used by the hook; it is
            accepted so callers can build one hook per block with a uniform call.

    Returns:
        A callable accepting ``(*args, **kwargs)``.
    """
    def _inject_steering(*hook_args, **hook_kwargs):
        forwarded = (sae, steering_feature) + hook_args
        return add_feature_on_text_prompt(*forwarded, **hook_kwargs)

    return _inject_steering
|
|
|
|
|
|
def load_models(model_id='stabilityai/sdxl-turbo', sae_path="Checkpoints/dahyecheckpoint"):
    """Load the hooked SDXL pipeline, the list of hooked block names and the SAE.

    Args:
        model_id: Model identifier or local path passed to
            ``HookedStableDiffusionXLPipeline.from_pretrained``.
        sae_path: Directory holding the SAE checkpoint. Weights are loaded from
            its ``final`` subdirectory.

    Returns:
        ``(pipe, blocks_to_save, sae)`` on success, or ``(None, None, None)`` if
        any step fails. In that case the error is printed.
    """
    try:
        pipe = HookedStableDiffusionXLPipeline.from_pretrained(model_id)
        # Silence the pipeline's progress bar.
        pipe.set_progress_bar_config(disable=True)

        # Text-encoder layers whose outputs are cached for the steering prompt
        # and hooked during steered generation (layer 10 of `text_encoder`,
        # layer 28 of `text_encoder_2`).
        blocks_to_save = [
            'text_encoder.text_model.encoder.layers.10',
            'text_encoder_2.text_model.encoder.layers.28',
        ]

        sae = SparseAutoencoder.load_from_disk(os.path.join(sae_path, 'final'))

        return pipe, blocks_to_save, sae

    except Exception as e:
        # Deliberate app-level boundary: report the failure and return a None
        # triple so callers can show a "models failed to load" state instead of
        # crashing at startup.
        print(f"Error loading models: {e}")
        return None, None, None
|
|
|
|
|
|
def activation_modulation_across_prompt(pipe, sae, blocks_to_save, steer_prompt, strength, prompt, guidance_scale, num_inference_steps, seed):
    """Generate ``prompt`` while injecting SAE features derived from ``steer_prompt``.

    1. Run ``steer_prompt`` through the pipeline for a single step, only to cache
       the outputs of every block in ``blocks_to_save``. The image from this
       pass is discarded.
    2. Concatenate those cached outputs along the last dimension, encode them
       with ``sae.encode_without_topk``, scale by ``strength`` and project back
       through ``sae.decoder.weight.T`` to obtain the steering vector.
    3. Generate ``prompt`` with a hook on every block in ``blocks_to_save`` that
       applies the steering vector via ``modulate_hook_prompt``.

    Both passes use a fresh CPU generator seeded with ``seed``.

    Args:
        pipe: Hooked SDXL pipeline exposing ``run_with_cache`` and ``run_with_hooks``.
        sae: Sparse autoencoder providing ``encode_without_topk`` and ``decoder.weight``.
        blocks_to_save: Block names to cache and hook. Their concatenated outputs
            must match the SAE's input width.
        steer_prompt: Prompt whose SAE features are used for steering.
        strength: Scalar multiplier on the SAE activations. A negative value
            negates the steering vector.
        prompt: Main prompt to generate.
        guidance_scale: Guidance scale used for both passes.
        num_inference_steps: Steps for the final generation. The caching pass
            always uses a single step.
        seed: RNG seed (converted to ``int``).

    Returns:
        The first image of the steered pipeline output (``output.images[0]``).
    """
    # Slider-driven values may arrive as floats; torch's manual_seed and the
    # scheduler's step count require integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    def _generator():
        # A fresh generator per pass so both runs start from the same RNG state.
        return torch.Generator(device="cpu").manual_seed(seed)

    # Only the cache is needed from this pass; the generated output is discarded.
    _, cache = pipe.run_with_cache(
        steer_prompt,
        positions_to_cache=blocks_to_save,
        save_input=True,
        save_output=True,
        num_inference_steps=1,
        guidance_scale=guidance_scale,
        generator=_generator(),
    )

    # Concatenate in the same block order that is hooked below.
    activations = torch.cat([cache['output'][block] for block in blocks_to_save], dim=-1)
    # Drop two leading singleton dims; presumably (batch, step) in SDLens's
    # cache layout -- TODO confirm.
    activations = activations.squeeze(0).squeeze(0)

    with torch.no_grad():
        feature_acts = sae.encode_without_topk(activations)
        # Scale the SAE activations, then map them back to activation space
        # with the decoder weights.
        steering_feature = (feature_acts * strength) @ sae.decoder.weight.T

    hooks = {
        block: modulate_hook_prompt(sae, steering_feature, block)
        for block in blocks_to_save
    }

    output = pipe.run_with_hooks(
        prompt,
        position_hook_dict=hooks,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=_generator(),
    )

    return output.images[0]
|
|
|
|
|
|
def generate_comparison(prompt, steer_prompt, strength, seed, guidance_scale, steps):
    """Produce a standard image and a steered image for side-by-side comparison.

    Uses the module-level ``pipe``, ``sae`` and ``blocks_to_save`` populated at
    startup. Both images use the same seed, so they differ only by the steering.

    Args:
        prompt: Main text prompt.
        steer_prompt: Prompt whose SAE features steer the second image.
        strength: Modulation strength (lambda). Any non-zero value, positive or
            negative, applies steering. ``0`` returns the standard image in
            both slots.
        seed: RNG seed (converted to ``int``).
        guidance_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps (converted to ``int``).

    Returns:
        ``(standard_image, modified_image, status_message)``. On failure both
        images are solid-red 512x512 placeholders and the message describes
        the error.
    """
    if pipe is None or sae is None or blocks_to_save is None:
        placeholder = Image.new('RGB', (512, 512), color='red')
        return placeholder, placeholder, "Error: Models failed to load"

    try:
        # Slider values may arrive as floats; torch's manual_seed and the
        # scheduler's step count need integers.
        seed = int(seed)
        steps = int(steps)

        standard_image = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).images[0]

        # The strength slider spans negative values. The previous `> 0` check
        # silently ignored them while the status still reported a modulated
        # result, so any non-zero lambda now applies steering.
        if strength != 0:
            modified_image = activation_modulation_across_prompt(
                pipe, sae, blocks_to_save,
                steer_prompt, strength, prompt,
                guidance_scale, steps, seed,
            )
        else:
            # lambda == 0 scales the steering vector to zero, so reuse the
            # standard image instead of running the pipeline again.
            modified_image = standard_image

        comparison_message = f"Generated images with modulation strength: {strength}"
        return standard_image, modified_image, comparison_message

    except Exception as e:
        # UI boundary: surface the error in the status box instead of crashing.
        error_image = Image.new('RGB', (512, 512), color='red')
        return error_image, error_image, f"Error during generation: {str(e)}"
|
|
|
|
|
|
# Load everything once when the module is imported. generate_comparison reads
# these module-level names and checks them for None.
print("Loading models...")
pipe, blocks_to_save, sae = load_models()
print("Models loaded successfully!" if pipe is not None else "Failed to load models")
|
|
|
|
|
|
# Gradio UI definition. Components are created at import time, and the order
# and nesting of the `with` blocks below determine the on-screen layout.
with gr.Blocks(title="SDXL Activation Modulation") as app:
    gr.Markdown("# SDXL Activation Modulation Comparison")
    gr.Markdown("""
This app demonstrates activation modulation in Stable Diffusion XL using sparse autoencoders.
It compares standard SDXL-Turbo outputs with modulated outputs that can steer the generation based on a separate concept.
""")

    # Input controls.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your main image prompt here...", value="A photo of a tree")
            steer_prompt = gr.Textbox(label="Steering Prompt", placeholder="Enter concept to steer with...", value="tree with autumn leaves")
            # The range is symmetric around 0, so negative strengths are selectable.
            strength = gr.Slider(minimum=-2.5, maximum=2.5, value=0.8, step=0.05,
                                 label="Modulation Strength (λ)")

            with gr.Accordion("Advanced Settings", open=False):
                # Upper bound is 2**31 - 1.
                seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=61730, label="Seed")
                guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.5, label="Guidance Scale")
                steps = gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Inference Steps")

            generate_btn = gr.Button("Generate Comparison", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

    # Output images, side by side.
    with gr.Row():
        standard_output = gr.Image(label="Standard SDXL-Turbo")
        modified_output = gr.Image(label="Modulated Output")

    gr.Markdown("""
## Examples from the notebook:
- Main prompt: "A photo of a tree" with steering prompt: "tree with autumn leaves"
- Main prompt: "A dog" with steering prompt: "full shot"
- Main prompt: "A car" with steering prompt: "A blue car"
""")

    # Preset buttons.
    with gr.Row():
        example1 = gr.Button("Example 1: Tree with autumn leaves")
        example2 = gr.Button("Example 2: Dog with full shot")
        example3 = gr.Button("Example 3: Blue car")

    # Main action. generate_comparison returns (standard image, modulated
    # image, status text), matching the outputs list order.
    generate_btn.click(
        fn=generate_comparison,
        inputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps],
        outputs=[standard_output, modified_output, status]
    )

    # Preset handlers take no inputs. They only fill the input widgets, in
    # [prompt, steer_prompt, strength, seed, guidance_scale, steps] order, and do
    # not trigger generation themselves.
    example1.click(
        fn=lambda: ["A photo of a tree", "tree with autumn leaves", 0.5, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )

    example2.click(
        fn=lambda: ["A dog", "full shot", 0.4, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )

    example3.click(
        fn=lambda: ["A car", "A blue car", 0.3, 61730, 0.0, 3],
        inputs=None,
        outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps]
    )

    gr.Markdown("""
## How to Use
1. Enter your main prompt (what you want to generate)
2. Enter a steering prompt (concept to influence the generation)
3. Adjust the modulation strength slider (λ) - higher values mean stronger influence
4. Click "Generate Comparison" to see the results side by side
5. Use advanced settings if needed to adjust seed, guidance scale, or steps

## About
This app demonstrates activation modulation using a sparse autoencoder trained on SDXL text encoder layers.
The modulation allows steering the generation toward specific concepts without changing the main prompt.
""")


# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()