import torch
from PIL import Image
from src.interpretability import cross_attention_to_image
import numpy as np
import matplotlib.cm as cm


def resize_for_display(pil_img, max_dim=5000):
    """Downscale a PIL image so its longest side is at most max_dim pixels."""
    w, h = pil_img.size
    if max(w, h) <= max_dim:
        return pil_img
    scale = max_dim / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    return pil_img.resize((new_w, new_h), Image.LANCZOS)


@torch.no_grad()
def generate_rollout_for_demo(model, tokenizer, img, preprocess,
                              device="cuda", max_new_tokens=32, alpha=0.45):
    """Greedily decode a caption while recording cross-attention overlays.

    Returns the caption plus one overlay frame per generated token, both
    head-averaged and per-head, for the demo UI.
    """
    model.eval()

    # Preprocess to a 1 x C x H x W batch on the target device.
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    # Encode the image and project its embeddings into the T5 encoder's
    # hidden space.
    vision_out = model.vision_encoder(img_tensor)
    img_embeds = vision_out["image_embeds"]

    # Ensure a sequence dimension: (batch, dim) -> (batch, 1, dim).
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)

    # Start decoding from T5's decoder start token.
    decoder_input_ids = torch.tensor(
        [[model.t5.config.decoder_start_token_id]], device=device
    )

    generated_ids = []
    avg_frames = []
    labels = []
    per_head_frames = []

    num_heads = None

    # Greedy decoding loop: one forward pass per step so each new token's
    # cross-attention can be captured.
    for step in range(max_new_tokens):
        # The projected image embeddings stand in for precomputed encoder
        # outputs, so T5 skips its own encoder.
        outputs = model.t5(
            encoder_outputs=(projected,),
            decoder_input_ids=decoder_input_ids,
            output_attentions=True,
            return_dict=True,
        )

        # Cross-attention of the last decoder layer for the single batch
        # item: shape (num_heads, tgt_len, src_len).
        last_cross = outputs.cross_attentions[-1][0]
        num_heads = last_cross.size(0)

        # Average over heads, then take the row for the newest target position.
        attn_avg = last_cross.mean(dim=0)
        attn_vec = attn_avg[-1]
        heat_avg = cross_attention_to_image(attn_vec)

        # cross_attention_to_image may return (heatmap, extras) or a raw
        # array; normalize to a PIL image either way.
        if isinstance(heat_avg, tuple):
            heat_avg = heat_avg[0]
        if isinstance(heat_avg, np.ndarray):
            heat_avg = Image.fromarray((heat_avg * 255).astype("uint8"))

        avg_frames.append(
            overlay_attention_for_demo(img_tensor, heat_avg, alpha=alpha)
        )

        # Repeat the overlay for each attention head for the head-level view.
        head_overlays = []
        for h in range(num_heads):
            attn_vec_h = last_cross[h][-1]
            hmap = cross_attention_to_image(attn_vec_h)

            if isinstance(hmap, tuple):
                hmap = hmap[0]
            if isinstance(hmap, np.ndarray):
                hmap = Image.fromarray((hmap * 255).astype("uint8"))

            head_overlays.append(
                overlay_attention_for_demo(img_tensor, hmap, alpha=alpha)
            )

        per_head_frames.append(head_overlays)

        # Pick the next token greedily and record a label for this frame.
        next_token = outputs.logits[:, -1, :].argmax(-1)
        token_str = tokenizer.decode(next_token, skip_special_tokens=True)
        labels.append(f'Token #{step}: "{token_str}"')

        generated_ids.append(int(next_token))

        if next_token.item() == tokenizer.eos_token_id:
            break

        # Feed the new token back in for the next decoding step.
        decoder_input_ids = torch.cat(
            [decoder_input_ids, next_token.unsqueeze(0)], dim=1
        )

    caption = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return {
        "caption": caption,
        "avg": {
            "frames": avg_frames,
            "labels": labels,
        },
        "heads": {
            "frames": per_head_frames,
            "labels": labels,
            "num_heads": num_heads,
        },
    }


def overlay_attention_for_demo(image_tensor, heatmap, alpha=0.45):
    """Blend an attention heatmap over the (denormalized) input image."""
    # Undo batching and channel-first layout, then min-max normalize (with a
    # small epsilon against flat images) so the preprocessed tensor is
    # displayable regardless of its normalization.
    img = image_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img_uint8 = (img * 255).astype("uint8")

    # Match the heatmap to the image size; force single-channel so the
    # colormap lookup below receives a 2D array.
    heatmap = heatmap.convert("L").resize(
        (img_uint8.shape[1], img_uint8.shape[0]), Image.BILINEAR
    )
    heat_np = np.asarray(heatmap).astype("float32") / 255.0

    base = Image.fromarray(img_uint8).convert("RGBA")

    # Map scalar attention to RGBA via the inferno colormap, then blend
    # with a uniform alpha.
    colored = cm.inferno(heat_np)
    colored_uint8 = (colored * 255).astype("uint8")
    heat = Image.fromarray(colored_uint8).convert("RGBA")

    heat.putalpha(int(alpha * 255))

    blended = Image.alpha_composite(base, heat)
    return blended.convert("RGB")
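

# Minimal usage sketch (hypothetical): `load_demo_model`, the checkpoint
# path, and the image file below are placeholders, not part of this repo's
# verified API. It only illustrates how the returned dict is consumed, e.g.
# stitching the head-averaged frames into an animated GIF.
if __name__ == "__main__":
    from src.model import load_demo_model  # assumed helper; adapt to your setup

    model, tokenizer, preprocess = load_demo_model("checkpoints/demo.pt")
    img = Image.open("example.jpg").convert("RGB")

    result = generate_rollout_for_demo(model, tokenizer, img, preprocess)
    print("Caption:", result["caption"])

    frames = [resize_for_display(f) for f in result["avg"]["frames"]]
    if frames:
        frames[0].save(
            "attention_rollout.gif",
            save_all=True,
            append_images=frames[1:],
            duration=400,  # ms per frame
            loop=0,
        )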