patchsae-demo / app.py
hyesulim's picture
test: try fixing variable error
9fc25a3 verified
raw
history blame
27.1 kB
import gzip
import os
import pickle
from glob import glob
import threading
import psutil
from functools import lru_cache
import concurrent.futures
from typing import Dict, Tuple, List, Optional
from time import sleep
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Constants
# Side length in pixels of the square every input image is resized to on load.
IMAGE_SIZE = 400
# Dataset names; used to build the "MaPLE-<dataset>" adapted-model options.
DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
# Number of grid cells along each image side when mapping a click to a tile.
GRID_NUM = 14
# Root directory of the gzip-pickled per-model / per-image data read by get_data.
PKL_ROOT = "./data/out"
# Global cache with better type hints and error handling
class Cache:
    """Two-level in-memory store: category name -> key -> cached value."""

    def __init__(self):
        # Categories that exist from the start; `set` may add further ones.
        initial_categories = (
            'data_dict',
            'sae_data_dict',
            'model_data',
            'segmasks',
            'top_images',
            'precomputed_activations',
        )
        self.data: Dict[str, Dict] = {name: {} for name in initial_categories}

    def get(self, category: str, key: str, default=None):
        """Return the value stored under category/key, or `default` if either is absent."""
        bucket = self.data.get(category)
        if bucket is None:
            return default
        return bucket.get(key, default)

    def set(self, category: str, key: str, value):
        """Store `value` under category/key, creating the category when needed."""
        if category in self.data:
            self.data[category][key] = value
        else:
            self.data[category] = {key: value}

    def clear_category(self, category: str):
        """Drop every entry of `category`; unknown categories are ignored."""
        bucket = self.data.get(category)
        if bucket is not None:
            bucket.clear()


# Single process-wide cache instance shared by all functions in this module.
_CACHE = Cache()
def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
    """Load the demo images and SAE statistics into the global cache.

    Every file directly under ``image_root`` is opened in a thread pool,
    resized to IMAGE_SIZE x IMAGE_SIZE and stored in the 'data_dict' category
    under its base name truncated at the first ".". The SAE statistics are
    read from fixed paths under ``./data/sae_data``. Each load is best-effort:
    a failure is printed and skipped, never raised.

    Args:
        image_root: Directory whose entries are loaded as images.
        pkl_root: Accepted for interface compatibility; not used here.

    Returns:
        The cache's 'data_dict' and 'sae_data_dict' dictionaries (live
        references, not copies).
    """

    def load_image_file(image_file: str) -> Optional[Dict]:
        try:
            # Image.open is lazy and keeps the file open. resize() decodes the
            # pixels and returns an independent image, so the source handle is
            # closed right away instead of leaking one descriptor per image.
            with Image.open(image_file) as source:
                image = source.resize((IMAGE_SIZE, IMAGE_SIZE))
            return {
                "image": image,
                "image_path": image_file,
            }
        except Exception as e:
            print(f"Error loading image {image_file}: {e}")
            return None

    # Load images in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_file = {
            executor.submit(load_image_file, image_file): image_file
            for image_file in glob(f"{image_root}/*")
        }
        for future in concurrent.futures.as_completed(future_to_file):
            try:
                image_file = future_to_file[future]
                image_name = os.path.basename(image_file).split(".")[0]
                result = future.result()
                if result:
                    _CACHE.set('data_dict', image_name, result)
            except Exception as e:
                print(f"Error processing image future: {e}")

    # Load SAE data
    # NOTE: pickle executes arbitrary code on load; these files must come from
    # a trusted source (they are read from the app's own data directory).
    try:
        with open("./data/sae_data/mean_acts.pkl", "rb") as f:
            _CACHE.set('sae_data_dict', "mean_acts", pickle.load(f))
    except Exception as e:
        print(f"Error loading mean_acts.pkl: {e}")

    # Load mean act values
    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
    for dataset in datasets:
        try:
            with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
                # The container is created lazily, only once a file could be opened.
                if "mean_act_values" not in _CACHE.data['sae_data_dict']:
                    _CACHE.set('sae_data_dict', "mean_act_values", {})
                _CACHE.data['sae_data_dict']["mean_act_values"][dataset] = pickle.load(f)
        except Exception as e:
            print(f"Error loading mean act values for {dataset}: {e}")

    return _CACHE.data['data_dict'], _CACHE.data['sae_data_dict']
@lru_cache(maxsize=1024)
def get_data(image_name: str, model_name: str) -> np.ndarray:
    """Return the unpickled model output for one image, loading it on first use.

    The file ``{PKL_ROOT}/{model_name}/{image_name}.pkl.gz`` is read once and
    kept in the 'model_data' cache category. If loading fails, the error is
    printed and an empty array is returned.

    NOTE(review): because of ``lru_cache`` the empty-array failure result is
    memoized too, so a file that appears later is not picked up for the same
    arguments. The returned object is shared between callers, not copied.
    """
    cache_key = f"{model_name}_{image_name}"
    model_cache = _CACHE.data['model_data']
    if cache_key in model_cache:
        return model_cache[cache_key]

    data_dir = f"{PKL_ROOT}/{model_name}/{image_name}.pkl.gz"
    try:
        with gzip.open(data_dir, "rb") as f:
            model_cache[cache_key] = pickle.load(f)
    except Exception as e:
        print(f"Error loading model data for {cache_key}: {e}")
        return np.array([])
    return model_cache[cache_key]
@lru_cache(maxsize=1024)
def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
    """Return the SAE activations of one image with noisy latents zeroed.

    The activation is the first element of the loaded data when that is a
    list or tuple, otherwise the data itself. Latents whose "imagenet" mean
    activation exceeds 0.1 are treated as noisy and their columns are set to
    zero (only for arrays with at least two dimensions). Any error is printed
    and an empty array is returned.

    NOTE: when the loaded activation is already an ndarray, the zeroing is
    done in place on the object held by the model-data cache, so later
    ``get_data`` callers see the zeroed columns as well.
    """
    try:
        data = get_data(image_name, model_type)
        activation = data[0] if isinstance(data, (list, tuple)) else data
        if not isinstance(activation, np.ndarray):
            activation = np.array(activation)

        sae_stats = _CACHE.get('sae_data_dict', "mean_acts", {})
        mean_acts = sae_stats.get("imagenet", np.array([]))
        if mean_acts.size > 0 and activation.size > 0:
            noisy_columns = np.where(mean_acts > 0.1)[0]
            if activation.ndim >= 2:
                activation[:, noisy_columns] = 0
        return activation
    except Exception as e:
        print(f"Error getting activation distribution: {e}")
        return np.array([])
def get_grid_loc(evt: gr.SelectData, image: Image.Image) -> Tuple[int, int, int, int]:
    """Map a click position to its cell in the GRID_NUM x GRID_NUM grid.

    Args:
        evt: Select event; ``evt.index[0]`` / ``evt.index[1]`` are read as the
            x / y pixel coordinates of the click.
        image: Image whose width and height define the cell size.

    Returns:
        ``(grid_x, grid_y, cell_width, cell_height)`` with both grid indices
        guaranteed to lie in ``[0, GRID_NUM - 1]``.
    """
    x, y = evt.index[0], evt.index[1]
    # Guard against images smaller than the grid (integer division would give 0).
    cell_width = max(1, image.width // GRID_NUM)
    cell_height = max(1, image.height // GRID_NUM)
    # The image size is generally not a multiple of GRID_NUM (400 // 14 = 28,
    # 28 * 14 = 392), so clicks in the leftover strip along the right/bottom
    # edge would yield index GRID_NUM. Clamp them into the last cell so the
    # token index derived by callers stays valid.
    grid_x = min(max(x // cell_width, 0), GRID_NUM - 1)
    grid_y = min(max(y // cell_height, 0), GRID_NUM - 1)
    return grid_x, grid_y, cell_width, cell_height
def highlight_grid(evt: gr.SelectData, image_name: str) -> Image.Image:
    """Return a copy of the cached image with the clicked grid cell outlined in red.

    Returns None when no image is cached under ``image_name``.
    """
    image = _CACHE.get('data_dict', image_name, {}).get("image")
    if not image:
        return None

    col, row, cell_w, cell_h = get_grid_loc(evt, image)
    left = col * cell_w
    top = row * cell_h
    right = (col + 1) * cell_w
    bottom = (row + 1) * cell_h

    # Draw on a copy so the cached original stays unmarked.
    canvas = image.copy()
    ImageDraw.Draw(canvas).rectangle([left, top, right, bottom], outline="red", width=3)
    return canvas
def plot_activations(
    all_activation: np.ndarray,
    tile_activations: Optional[np.ndarray] = None,
    grid_x: Optional[int] = None,
    grid_y: Optional[int] = None,
    top_k: int = 5,
    colors: Tuple[str, str] = ("blue", "cyan"),
    model_name: str = "CLIP",
) -> go.Figure:
    """Build a line plot of latent activations with the top-k latents annotated.

    Args:
        all_activation: Image-level activation per latent; always plotted.
        tile_activations: Optional activation of one tile; plotted as a second
            line when given.
        grid_x, grid_y: Tile coordinates, used only in the tile legend label.
        top_k: Number of highest-valued latents annotated on each line.
        colors: ``(image_color, tile_color)`` for the two lines.
        model_name: The part after the last "-" is used as the legend prefix.

    Returns:
        The assembled figure.
    """
    fig = go.Figure()
    legend_prefix = model_name.split('-')[-1]

    def _draw_series(values, color, label):
        # One solid line over all latent indices ...
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(values)),
                y=values,
                mode="lines",
                name=label,
                line=dict(color=color, dash="solid"),
                showlegend=True,
            )
        )
        # ... plus an arrow labelled with the latent index on each top-k peak.
        for idx in np.argsort(values)[::-1][:top_k]:
            fig.add_annotation(
                x=idx,
                y=values[idx],
                text=str(idx),
                showarrow=True,
                arrowhead=2,
                ax=0,
                ay=-15,
                arrowcolor=color,
                opacity=0.7,
            )

    _draw_series(all_activation, colors[0], f"{legend_prefix} Image-level")
    if tile_activations is not None:
        _draw_series(
            tile_activations, colors[1], f"{legend_prefix} Tile ({grid_x}, {grid_y})"
        )

    fig.update_layout(
        title="Activation Distribution",
        xaxis_title="SAE latent index",
        yaxis_title="Activation Value",
        template="plotly_white",
        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
    )
    return fig
def get_segmask(selected_image: str, slider_value: int, model_type: str) -> Optional[np.ndarray]:
    """Return the image with regions of zero activation for one latent darkened.

    The activation column of latent ``slider_value`` is reshaped (dropping its
    first entry) into a GRID_NUM x GRID_NUM map, upsampled to the image size
    and min-max normalised. Pixels whose normalised value is exactly 0 are
    replaced by a darkened version of the image; the rest keep their colour.
    Results are cached per (image, latent, model).

    Args:
        selected_image: Key of the image in the 'data_dict' cache category.
        slider_value: Index of the SAE latent.
        model_type: Model name passed on to ``get_data``.

    Returns:
        An ``(H, W, 4)`` uint8 RGBA array, or None if the image is unknown or
        anything fails (the error is printed).
    """
    cache_key = f"{selected_image}_{slider_value}_{model_type}"
    cached_mask = _CACHE.get('segmasks', cache_key)
    if cached_mask is not None:
        return cached_mask
    try:
        image = _CACHE.get('data_dict', selected_image, {}).get("image")
        if image is None:
            return None

        sae_act = get_data(selected_image, model_type)[0]
        temp = sae_act[:, slider_value]
        # presumably the first row is a class token and the remaining
        # GRID_NUM * GRID_NUM rows are the patch tokens - TODO confirm
        mask = torch.tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
        mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
        if mask.size == 0:
            return None
        # The small epsilon avoids division by zero for a constant map.
        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)

        base_opacity = 30
        # Force three colour channels. A grayscale or palette image would give
        # a 2-D array whose `[..., :3]` slice silently broadcasts into the
        # square overlay as garbage instead of raising.
        image_array = np.array(image.convert("RGB"))
        rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
        rgba_overlay[..., :3] = image_array
        darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
        inactive = mask == 0
        rgba_overlay[inactive, :3] = darkened_image[inactive]
        rgba_overlay[..., 3] = 255

        _CACHE.set('segmasks', cache_key, rgba_overlay)
        return rgba_overlay
    except Exception as e:
        print(f"Error generating segmentation mask: {e}")
        return None
def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
    """Return the three reference images of a latent, one per reference dataset.

    Images are read from ``./data/top_images_masked`` when ``toggle_btn`` is
    true, otherwise from ``./data/top_images``. A missing file is replaced by
    a white 256x256 placeholder. The resulting list is cached per
    (latent, toggle) pair.
    """
    cache_key = f"{slider_value}_{toggle_btn}"
    cached_images = _CACHE.get('top_images', cache_key)
    if cached_images is not None:
        return cached_images

    if toggle_btn:
        dataset_path = "./data/top_images_masked"
    else:
        dataset_path = "./data/top_images"

    images = []
    for dataset in ("imagenet", "imagenet-sketch", "caltech101"):
        path = os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
        if os.path.exists(path):
            images.append(Image.open(path))
        else:
            images.append(Image.new("RGB", (256, 256), (255, 255, 255)))

    _CACHE.set('top_images', cache_key, images)
    return images
# UI Event Handlers
def plot_activation_distribution(
    evt: Optional[gr.EventData],
    selected_image: str,
    model_name: str
) -> go.Figure:
    """Build the two-row activation figure: CLIP on top, ``model_name`` below.

    Each row shows the image-level mean activation. When ``evt`` carries
    event data and the image is cached, the activation of the clicked tile is
    added to each row as a second line.
    """
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
    )

    def _build_panel(panel_model, colors):
        activation = get_activation_distribution(selected_image, panel_model)
        all_activation = activation.mean(0)
        tile_activations = None
        grid_x = None
        grid_y = None
        if evt is not None and evt._data is not None:
            image = _CACHE.get('data_dict', selected_image, {}).get("image")
            if image:
                grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                # +1 skips the first row of the activation array.
                token_idx = grid_y * GRID_NUM + grid_x + 1
                tile_activations = activation[token_idx]
        return plot_activations(
            all_activation,
            tile_activations,
            grid_x,
            grid_y,
            top_k=5,
            model_name=panel_model,
            colors=colors,
        )

    # Build both panels before anything is attached to the combined figure.
    clip_panel = _build_panel("CLIP", ("#00b4d8", "#90e0ef"))
    adapted_panel = _build_panel(model_name, ("#ff5a5f", "#ffcad4"))

    for sub_fig, row, yref in ((clip_panel, 1, "y1"), (adapted_panel, 2, "y2")):
        for trace in sub_fig.data:
            fig.add_trace(trace, row=row, col=1)
        # Re-point each annotation at the y-axis of the row it now lives in.
        for annotation in sub_fig.layout.annotations:
            annotation.update(yref=yref)
            fig.add_annotation(annotation)

    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
    fig.update_layout(
        template="plotly_white",
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig
def show_activation_heatmap_clip(
    selected_image: str,
    slider_value: str,
    toggle_btn: bool
):
    """Return the CLIP heatmap outputs flattened into one 7-tuple for the UI.

    The tuple is ``(overlay, image_1, image_2, image_3, text_1, text_2,
    text_3)``, i.e. the result of ``show_activation_heatmap`` for "CLIP" with
    the two three-element lists spread out.
    """
    rgba_overlay, top_images, act_values = show_activation_heatmap(
        selected_image, slider_value, "CLIP", toggle_btn
    )
    sleep(0.1)
    reference_images = tuple(top_images[i] for i in range(3))
    value_texts = tuple(act_values[i] for i in range(3))
    return (rgba_overlay, *reference_images, *value_texts)
def show_activation_heatmap(
    selected_image: str,
    slider_value: str,
    model_type: str,
    toggle_btn: bool = False
) -> Tuple[np.ndarray, List[Image.Image], List[str]]:
    """Collect the overlay, reference images and activation texts for a latent.

    Args:
        selected_image: Key of the image in the cache.
        slider_value: Radio option such as ``"CLIP-123"``; the integer after
            the last "-" is the latent index.
        model_type: Model whose activations drive the overlay.
        toggle_btn: Passed to ``get_top_images`` to select the masked variants.

    Returns:
        ``(rgba_overlay, top_images, act_values)``. ``act_values`` holds one
        markdown line per reference dataset with the first five stored
        activation values of the latent, or "N/A" when that dataset's
        statistics were not loaded.

    Raises:
        ValueError: If ``slider_value`` does not end in an integer.
    """
    slider_value = int(slider_value.split("-")[-1])
    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
    top_images = get_top_images(slider_value, toggle_btn)

    mean_act_values = _CACHE.get('sae_data_dict', "mean_act_values", {})
    act_values = []
    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
        dataset_values = mean_act_values.get(dataset)
        if dataset_values is None:
            # Loading the statistics is best-effort, so a missing dataset must
            # not crash the UI. Previously an empty 1-D default array was
            # indexed with two indices here, which raised IndexError.
            act_values.append("#### Activation values: N/A")
            continue
        act_value = dataset_values[slider_value, :5]
        act_value = " | ".join(str(round(value, 3)) for value in act_value)
        act_values.append(f"#### Activation values: {act_value}")
    return rgba_overlay, top_images, act_values
def show_activation_heatmap_maple(
    selected_image: str,
    slider_value: str,
    model_name: str
) -> np.ndarray:
    """Return the overlay of the selected latent for the adapted model.

    ``slider_value`` is a radio option such as ``"MaPLE-42"``; the integer
    after the last "-" is used as the latent index.
    """
    latent_idx = int(slider_value.split("-")[-1])
    overlay = get_segmask(selected_image, latent_idx, model_name)
    sleep(0.1)
    return overlay
def get_init_radio_options(selected_image: str, model_name: str) -> List[str]:
    """Return the initial radio labels from the image-level top latents.

    The five latents with the highest mean activation are collected for CLIP
    and for ``model_name`` and turned into labels by ``get_radio_names``.
    """

    def _top_latents(model: str, top_k: int = 5) -> Dict:
        mean_activation = get_activation_distribution(selected_image, model).mean(0)
        strongest = list(np.argsort(mean_activation)[::-1][:top_k])
        by_latent = {latent: mean_activation[latent] for latent in strongest}
        # Order the mapping by activation, highest first.
        return dict(sorted(by_latent.items(), key=lambda item: item[1], reverse=True))

    clip_neuron_dict = _top_latents("CLIP")
    maple_neuron_dict = _top_latents(model_name)
    return get_radio_names(clip_neuron_dict, maple_neuron_dict)
def get_radio_names(
    clip_neuron_dict: Dict[int, float],
    maple_neuron_dict: Dict[int, float]
) -> List[str]:
    """Turn the two latent -> activation mappings into radio button labels.

    Labels come in three groups, each capped at five entries:
    ``common-<idx>`` for latents in both mappings (highest of the two
    activations first), then ``CLIP-<idx>`` and ``MaPLE-<idx>`` for latents
    found in only one mapping.

    NOTE(review): the two model-only groups are ordered by latent index
    (descending), not by activation like the common group - confirm this is
    intended.
    """
    clip_ids = set(clip_neuron_dict)
    maple_ids = set(maple_neuron_dict)

    def _strongest(idx):
        return max(clip_neuron_dict[idx], maple_neuron_dict[idx])

    shared = sorted(clip_ids.intersection(maple_ids), key=_strongest, reverse=True)
    clip_only = sorted(clip_ids - maple_ids, reverse=True)
    maple_only = sorted(maple_ids - clip_ids, reverse=True)

    return (
        [f"common-{i}" for i in shared[:5]]
        + [f"CLIP-{i}" for i in clip_only[:5]]
        + [f"MaPLE-{i}" for i in maple_only[:5]]
    )
def update_radio_options(
    evt: Optional[gr.EventData],
    selected_image: str,
    model_name: str
) -> gr.Radio:
    """Rebuild the latent radio group after an image, model or tile change.

    The five image-level top latents are always included for CLIP and for
    ``model_name``. When the event payload carries a list-valued "index"
    (treated as a click position on the image), the five top latents of the
    clicked tile are added as well.

    Args:
        evt: Event data of the triggering event, or None.
        selected_image: Key of the image in the cache.
        model_name: Name of the adapted model compared against CLIP.

    Returns:
        A ``gr.Radio`` with the new choices; the first choice is preselected,
        or nothing is preselected when there are no choices.
    """
    # Events without a payload, or whose payload has no "index" entry, must
    # not fail: look the key up defensively instead of subscripting.
    payload = getattr(evt, "_data", None)
    has_click = isinstance(payload, dict) and isinstance(payload.get("index"), list)

    def _get_top_activation(model: str) -> Dict:
        neuron_dict = {}
        all_activation = get_activation_distribution(selected_image, model)
        image_activation = all_activation.mean(0)
        for top_neuron in list(np.argsort(image_activation)[::-1][:5]):
            neuron_dict[top_neuron] = image_activation[top_neuron]

        if has_click:
            image = _CACHE.get('data_dict', selected_image, {}).get("image")
            if image:
                grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                # +1 skips the first row of the activation array.
                token_idx = grid_y * GRID_NUM + grid_x + 1
                tile_activations = all_activation[token_idx]
                # A tile value replaces the image-level value of the same latent.
                for top_neuron in list(np.argsort(tile_activations)[::-1][:5]):
                    neuron_dict[top_neuron] = tile_activations[top_neuron]
        return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

    clip_neuron_dict = _get_top_activation("CLIP")
    maple_neuron_dict = _get_top_activation(model_name)
    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
    # Avoid IndexError when no latent could be determined.
    selected = radio_choices[0] if radio_choices else None
    return gr.Radio(choices=radio_choices, label="Top activating SAE latent", value=selected)
def update_markdown(option_value: str) -> Tuple[str, str]:
    """Return the two section headings for the selected radio option.

    The latent index is the integer after the last "-" of ``option_value``.
    """
    latent_idx = int(option_value.rsplit("-", 1)[-1])
    return (
        f"## Segmentation mask for the selected SAE latent - {latent_idx}",
        f"## Top reference images for the selected SAE latent - {latent_idx}",
    )
def update_all(
    selected_image: str,
    slider_value: str,
    toggle_btn: bool,
    model_name: str
) -> Tuple:
    """Recompute every output that depends on the selected latent.

    Returns a 10-tuple: CLIP overlay, adapted-model overlay, the three
    reference images, the three activation texts, and the two headings.
    """
    (
        clip_overlay,
        reference_1,
        reference_2,
        reference_3,
        values_1,
        values_2,
        values_3,
    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
    adapted_overlay = show_activation_heatmap_maple(
        selected_image, slider_value, model_name
    )
    mask_heading, reference_heading = update_markdown(slider_value)
    return (
        clip_overlay,
        adapted_overlay,
        reference_1,
        reference_2,
        reference_3,
        values_1,
        values_2,
        values_3,
        mask_heading,
        reference_heading,
    )
def monitor_memory_usage():
    """Print the process memory usage and drop derived caches above 80 %.

    Only the 'segmasks', 'top_images' and 'precomputed_activations'
    categories are cleared; loaded images and model data are kept.
    """
    process = psutil.Process()
    mem_info = process.memory_info()
    mem_percent = process.memory_percent()
    print(f"""
Memory Usage:
- RSS: {mem_info.rss / (1024**2):.2f} MB
- VMS: {mem_info.vms / (1024**2):.2f} MB
- Percent: {mem_percent:.1f}%
- Cache Sizes: {[len(cache) for cache in _CACHE.data.values()]}
""")
    if not mem_percent > 80:
        return
    print("Memory usage too high, clearing caches...")
    for category in ('segmasks', 'top_images', 'precomputed_activations'):
        _CACHE.clear_category(category)
def start_memory_monitor(interval: int = 300):
    """Run a memory check now and schedule the next one in ``interval`` seconds.

    Args:
        interval: Seconds between two consecutive checks.
    """
    monitor_memory_usage()
    # Forward `interval` to the next run; without it every run after the
    # first silently fell back to the 300 s default.
    timer = threading.Timer(interval, start_memory_monitor, args=(interval,))
    # Daemon thread: the endlessly re-armed timer must not keep the process
    # alive once the main thread has finished.
    timer.daemon = True
    timer.start()
# Initialize the application
# Runs at import time: fills the global cache with the images and SAE
# statistics. Load failures are only printed, so the cache may be incomplete.
data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
# Image shown on startup; presumably "./data/image" contains a file with this
# base name - if not, the initial image value below is None.
default_image_name = "christmas-imagenet"
# Create the Gradio interface
# NOTE(review): several handlers below take `evt` as their first parameter
# although it is not listed in `inputs`; presumably Gradio supplies it from
# the parameter's type annotation - confirm against the installed version.
with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
.image-row .gr-image { margin: 0 !important; padding: 0 !important; }
.image-row img { width: auto; height: 50px; }
""",
) as demo:
    # Top row: image selection (left) and activation plot (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Select input image and patch on the image")
            image_selector = gr.Dropdown(
                choices=list(_CACHE.data['data_dict'].keys()),
                value=default_image_name,
                label="Select Image",
            )
            image_display = gr.Image(
                value=_CACHE.get('data_dict', default_image_name, {}).get("image"),
                type="pil",
                interactive=True,
            )
            # Picking another image name swaps the displayed image.
            image_selector.change(
                fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                inputs=image_selector,
                outputs=image_display,
            )
            # Clicking on the image outlines the selected grid cell in red.
            image_display.select(
                fn=highlight_grid,
                inputs=[image_selector],
                outputs=[image_display]
            )
        with gr.Column():
            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
            model_selector = gr.Dropdown(
                choices=model_options,
                value=model_options[0],
                label="Select adapted model (MaPLe)",
            )
            # Initial plot is computed eagerly, while the UI is being built.
            init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
            neuron_plot = gr.Plot(value=init_plot, show_label=False)
            # The plot is refreshed on image change, tile click and model change.
            image_selector.change(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
            image_display.select(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
            # Changing the model puts the cached, un-highlighted image back.
            model_selector.change(
                fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                inputs=[image_selector],
                outputs=image_display,
            )
            model_selector.change(
                fn=plot_activation_distribution,
                inputs=[image_selector, model_selector],
                outputs=neuron_plot,
            )
    # Bottom row: segmentation overlays (left) and latent picker with
    # reference images (right).
    with gr.Row():
        with gr.Column():
            # Initial latent choices come from the image-level top activations.
            radio_names = get_init_radio_options(default_image_name, model_options[0])
            feature_idx = radio_names[0].split("-")[-1]
            markdown_display = gr.Markdown(
                f"## Segmentation mask for the selected SAE latent - {feature_idx}"
            )
            init_seg, init_tops, init_values = show_activation_heatmap(
                default_image_name, radio_names[0], "CLIP"
            )
            gr.Markdown("### Localize SAE latent activation using CLIP")
            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
            # Only the overlay is needed for the adapted model; the reference
            # images and activation texts returned alongside are discarded.
            init_seg_maple, _, _ = show_activation_heatmap(
                default_image_name, radio_names[0], model_options[0]
            )
            gr.Markdown("### Localize SAE latent activation using MaPLE")
            seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
        with gr.Column():
            gr.Markdown("## Top activating SAE latent index")
            radio_choices = gr.Radio(
                choices=radio_names,
                label="Top activating SAE latent",
                interactive=True,
                value=radio_names[0],
            )
            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
            markdown_display_2 = gr.Markdown(
                f"## Top reference images for the selected SAE latent - {feature_idx}"
            )
            gr.Markdown("### ImageNet")
            top_image_1 = gr.Image(value=init_tops[0], type="pil", show_label=False)
            act_value_1 = gr.Markdown(init_values[0])
            gr.Markdown("### ImageNet-Sketch")
            top_image_2 = gr.Image(value=init_tops[1], type="pil", show_label=False)
            act_value_2 = gr.Markdown(init_values[1])
            gr.Markdown("### Caltech101")
            top_image_3 = gr.Image(value=init_tops[2], type="pil", show_label=False)
            act_value_3 = gr.Markdown(init_values[2])
    # Event handlers
    # The radio choices are rebuilt on tile click, model change and image
    # selection.
    image_display.select(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    model_selector.change(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    # NOTE(review): this hooks `.select` while the other image_selector
    # handlers use `.change` - confirm the difference is intended.
    image_selector.select(
        fn=update_radio_options,
        inputs=[image_selector, model_selector],
        outputs=[radio_choices],
    )
    # Selecting a latent refreshes both overlays, the reference images, the
    # activation texts and the two headings.
    radio_choices.change(
        fn=update_all,
        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
        outputs=[
            seg_mask_display,
            seg_mask_display_maple,
            top_image_1,
            top_image_2,
            top_image_3,
            act_value_1,
            act_value_2,
            act_value_3,
            markdown_display,
            markdown_display_2,
        ],
    )
    # The checkbox only affects the CLIP-side outputs (the masked vs. plain
    # reference images); the adapted-model overlay is left untouched.
    toggle_btn.change(
        fn=show_activation_heatmap_clip,
        inputs=[image_selector, radio_choices, toggle_btn],
        outputs=[
            seg_mask_display,
            top_image_1,
            top_image_2,
            top_image_3,
            act_value_1,
            act_value_2,
            act_value_3,
        ],
    )
if __name__ == "__main__":
    # Initialize memory monitoring
    start_memory_monitor()
    # Get system memory info
    mem = psutil.virtual_memory()
    total_ram_gb = mem.total / (1024**3)
    # psutil.cpu_count() returns None when the count cannot be determined;
    # fall back to one so the min() below cannot raise TypeError.
    cpu_total = psutil.cpu_count() or 1
    # One queue slot per GiB of RAM, capped at 20. At least one slot, because
    # machines with less than 1 GiB would otherwise get a zero-sized queue.
    queue_size = max(1, min(20, int(total_ram_gb)))
    try:
        print("Starting application initialization...")
        # Precompute common data
        # Warms the caches of get_data / get_activation_distribution. The
        # stored means are not read anywhere in the visible source.
        print("Precomputing activation patterns...")
        model_names = ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]
        for image_name in _CACHE.data['data_dict']:
            for model_name in model_names:
                try:
                    activation = get_activation_distribution(image_name, model_name)
                    cache_key = f"activation_{model_name}_{image_name}"
                    _CACHE.set('precomputed_activations', cache_key, activation.mean(0))
                except Exception as e:
                    print(f"Error precomputing activation for {image_name}, {model_name}: {e}")
        print("Starting Gradio interface...")
        # Launch the app with optimized settings
        demo.queue(max_size=queue_size)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            max_threads=min(16, cpu_total)
        )
    except Exception as e:
        print(f"Critical error during startup: {e}")
        # Attempt to clean up resources
        _CACHE.data.clear()
        raise