import spaces
from typing import List, Tuple, Dict
import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


def get_attention_models() -> List[str]:
"""Get a list of timm models that have attention blocks."""
all_models = timm.list_pretrained()
# FIXME Focusing on ViT models for initial impl
    attention_models = [
        model for model in all_models
        if model.lower().startswith(('vit', 'deit', 'beit', 'eva'))
    ]
    return attention_models


def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
"""Load a model from timm and prepare it for attention extraction."""
timm.layers.set_fused_attn(False)
model = create_model(model_name, pretrained=True)
model = model.cuda() # Move the model to CUDA
model.eval()
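    # method='fx' uses torch.fx graph tracing to locate and return the
    # per-block attention maps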
extractor = AttentionExtract(model, method='fx')
return model, extractor
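

# spaces.GPU requests a GPU for the duration of the call on Hugging Face
# ZeroGPU Spaces; CUDA work must happen inside the decorated functions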
@spaces.GPU
def process_image(
image: Image.Image,
model: torch.nn.Module,
extractor: AttentionExtract
) -> Dict[str, torch.Tensor]:
"""Process the input image and get the attention maps."""
# Get the correct transform for the model
config = model.pretrained_cfg
transform = create_transform(
input_size=config['input_size'],
crop_pct=config['crop_pct'],
mean=config['mean'],
std=config['std'],
interpolation=config['interpolation'],
is_training=False
)
# Preprocess the image and move to CUDA
tensor = transform(image).unsqueeze(0).cuda()
# Extract attention maps
attention_maps = extractor(tensor)
    return attention_maps


def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
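    """Blend a [0, 1] heatmap `mask` into an RGB uint8 `image`, pushing masked
    pixels toward `color` (RGB fractions) with opacity `alpha`."""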
# Ensure mask and image have the same shape
mask = mask[:, :, np.newaxis]
mask = np.repeat(mask, 3, axis=2)
# Convert color to numpy array
color = np.array(color)
# Apply mask
masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)


def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
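    """Compute attention rollout (Abnar & Zuidema, 2020): fuse heads, mix each
    layer's attention with the identity to model residual connections,
    row-normalize, and multiply across layers. Expects per-layer tensors of
    shape (num_heads, num_tokens, num_tokens); returns a [0, 1] numpy mask
    over the patch grid as seen from the first prefix (class) token.
    """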
device = attentions[0].device
result = torch.eye(attentions[0].size(-1)).to(device)
with torch.no_grad():
for attention in attentions:
            if head_fusion.startswith('mean'):  # covers both 'mean' and 'mean_std'
attention_heads_fused = attention.mean(dim=0)
elif head_fusion == "max":
attention_heads_fused = attention.amax(dim=0)
elif head_fusion == "min":
attention_heads_fused = attention.amin(dim=0)
else:
                raise ValueError(f"Unsupported head fusion type: {head_fusion}")
# Discard the lowest attentions, but don't discard the prefix tokens
            num_tokens = attention_heads_fused.size(-1)
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            # Keep attention flowing *into* the prefix tokens: protect entries
            # whose key (column) index falls within the prefix range
            indices = indices[indices % num_tokens >= num_prefix_tokens]
            flat[indices] = 0
            I = torch.eye(attention_heads_fused.size(-1)).to(device)
            # Mix in the identity to account for residual connections, then
            # re-normalize each row so it stays a probability distribution
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
# Look at the total attention between the prefix tokens (usually class tokens)
# and the image patches
mask = result[0, num_prefix_tokens:]
width = int(mask.size(-1) ** 0.5)
mask = mask.reshape(width, width).cpu().numpy()
mask = mask / np.max(mask)
    return mask


@spaces.GPU
def visualize_attention(
image: Image.Image,
model_name: str,
head_fusion: str,
discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
"""Visualize attention maps and rollout for the given image and model."""
    image = image.convert('RGB')  # Normalize RGBA/grayscale uploads to RGB
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)
num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1) # Default to 1 class token if not specified
# Convert PIL Image to numpy array
image_np = np.array(image)
# Create visualizations
visualizations = []
attentions_for_rollout = []
for layer_name, attn_map in attention_maps.items():
print(f"Attention map shape for {layer_name}: {attn_map.shape}")
attn_map = attn_map[0].detach() # Remove batch dimension and detach
attentions_for_rollout.append(attn_map)
attn_map = attn_map[:, :, num_prefix_tokens:] # Remove prefix tokens for visualization
if head_fusion == 'mean_std':
attn_map = attn_map.mean(0) / attn_map.std(0)
elif head_fusion == 'mean':
attn_map = attn_map.mean(0)
elif head_fusion == 'max':
attn_map = attn_map.amax(0)
elif head_fusion == 'min':
attn_map = attn_map.amin(0)
else:
raise ValueError(f"Invalid head fusion method: {head_fusion}")
# Use the first token's attention (usually the class token)
attn_map = attn_map[0]
# Reshape the attention map to 2D
num_patches = int(attn_map.shape[0] ** 0.5)
attn_map = attn_map.reshape(num_patches, num_patches)
# Interpolate to match image size
attn_map = attn_map.unsqueeze(0).unsqueeze(0)
attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
attn_map = attn_map.squeeze().cpu().detach().numpy() # Move to CPU, detach, and convert to numpy
        # Normalize attention map to [0, 1]; epsilon guards a constant map
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
# Original image
ax1.imshow(image_np)
ax1.set_title("Original Image")
ax1.axis('off')
# Attention map overlay
masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0)) # Red mask
ax2.imshow(masked_image)
ax2.set_title(f'Attention Map for {layer_name}')
ax2.axis('off')
plt.tight_layout()
# Convert plot to image
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib removed
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        vis_image = Image.fromarray(data)
visualizations.append(vis_image)
plt.close(fig)
# Ensure tensors are on CPU and detached before converting to numpy
attentions_for_rollout = [attn.cpu().detach() for attn in attentions_for_rollout]
# Calculate rollout
rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
# Create rollout visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
# Original image
ax1.imshow(image_np)
ax1.set_title("Original Image")
ax1.axis('off')
# Rollout overlay
rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0)) # Red mask
ax2.imshow(masked_image)
ax2.set_title('Attention Rollout')
ax2.axis('off')
plt.tight_layout()
# Convert plot to image
    fig.canvas.draw()
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    rollout_image = Image.fromarray(data)
plt.close(fig)
    return visualizations, rollout_image


# Create Gradio interface
iface = gr.Interface(
fn=visualize_attention,
inputs=[
gr.Image(type="pil", label="Input Image"),
gr.Dropdown(choices=get_attention_models(), label="Select Model"),
gr.Dropdown(
choices=['mean_std', 'mean', 'max', 'min'],
label="Head Fusion Method",
value='mean' # Default value
),
gr.Slider(0, 1, 0.9, label="Discard Ratio", info="Ratio of lowest attentions to discard")
],
outputs=[
gr.Gallery(label="Attention Maps"),
gr.Image(label="Attention Rollout")
],
title="Attention Map Visualizer for timm Models",
description="Upload an image and select a timm model to visualize its attention maps."
)
# Launch the interface
iface.launch()