# coco-demo / app.py
# Uploaded by evanec — commit 21f2675 ("Upload 11 files", verified), 5.4 kB.
import gradio as gr
import torch
from src.utils import load_experiment
from demo.interpretability_demo import generate_rollout_for_demo
from data.transforms import build_coco_transform
# Load Model + Preprocessing
# Checkpoint identifier handed directly to load_experiment.
CHECKPOINT = "experiment_21_full_vlm_ceiling/20251127_231732"

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# load_experiment yields (model, tokenizer, meta, config); meta is not used below.
model, tokenizer, meta, config = load_experiment(CHECKPOINT, device=device)

# Build the image transform at the resolution recorded in the experiment config.
image_size = config["model"]["image_size"]
preprocess = build_coco_transform(image_size)
def step_prev(step):
    """Move the token cursor back by one, never below index 0."""
    previous = step - 1
    # Same selection rule as max(previous, 0): keep `previous` unless 0 is strictly larger.
    return 0 if previous < 0 else previous
def step_next(step, max_step):
    """Move the token cursor forward by one, clamped to ``max_step``.

    Args:
        step: Current token index from the slider.
        max_step: Last valid token index, or ``None`` when no rollout has been
            generated yet (the backing ``gr.State`` starts out empty).

    Returns:
        The next index, or ``step`` unchanged when there is no rollout to step through.
    """
    # The Next button is visible before any rollout exists; min(x, None) would raise.
    if max_step is None:
        return step
    return min(step + 1, max_step)
# Backend Logic
def run_full_rollout(img, max_tokens, alpha):
    """Caption ``img`` and build the per-token cross-attention overlays for the UI.

    Args:
        img: Image from the ``gr.Image(type="pil")`` component, or ``None`` if
            the user cleared it.
        max_tokens: Maximum number of caption tokens to generate.
        alpha: Overlay transparency forwarded to ``generate_rollout_for_demo``.

    Returns:
        A 6-tuple matching the click/load output components:
        ``(caption, first_avg_frame, avg_frames, head_frames, labels, max_step)``.
        When no frames were produced, the four frame/label slots are ``None``
        and ``max_step`` is 0.

    Raises:
        gr.Error: If no image was provided (shown to the user as a message).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    data = generate_rollout_for_demo(
        model,
        tokenizer,
        img,
        preprocess,
        device=device,
        # Slider values may arrive as floats; a token budget must be an int.
        max_new_tokens=int(max_tokens),
        alpha=alpha,
    )
    caption = data["caption"]
    avg_rollout = data["avg"]["frames"]
    heads_rollout = data["heads"]["frames"]
    labels = data["avg"]["labels"]

    # Still return exactly one value per output component (6); returning fewer
    # makes Gradio reject the whole result.
    if len(avg_rollout) == 0:
        return caption, None, None, None, None, 0

    max_step = len(avg_rollout) - 1
    return caption, avg_rollout[0], avg_rollout, heads_rollout, labels, max_step
def update_display(step, mode, avg_rollout, heads_rollout, labels):
    """Build updates for the averaged overlay, the token label, and the head gallery.

    Args:
        step: Token index chosen on the slider; coerced to int and clamped into range.
        mode: ``"Averaged"`` shows the averaged overlay; any other value shows
            the per-head gallery.
        avg_rollout: Averaged overlay frames indexed by token step, or ``None``
            before any rollout has run.
        heads_rollout: Per-step lists of per-head frames (shown in the gallery).
        labels: Token label for each step.

    Returns:
        ``(attention_img update, label text, head_gallery update)``.
    """
    # No rollout yet, or an empty one: show an empty overlay and hide the gallery.
    # (Checked with len() so a non-list sequence of frames also works.)
    if avg_rollout is None or len(avg_rollout) == 0:
        return gr.update(visible=True, value=None), "", gr.update(visible=False)

    # Slider values can arrive as floats (or None); list indexing needs an int.
    step = int(step or 0)
    step = max(0, min(step, len(avg_rollout) - 1))
    label = labels[step]

    if mode == "Averaged":
        return (
            gr.update(visible=True, value=avg_rollout[step]),  # show averaged overlay
            label,
            gr.update(visible=False),  # hide gallery
        )

    # All Heads mode
    frames = heads_rollout[step]
    return (
        gr.update(visible=False),  # hide averaged overlay
        label,
        gr.update(visible=True, value=frames),  # show gallery
    )
# Gradio UI
# Layout: input image beside the generation controls and token navigator,
# followed by the attention views (averaged overlay + token label, or per-head gallery).
with gr.Blocks(
    title="Team Coco — Image Captioning + Cross-Attention Viz",
    css="""
.token-box textarea {
font-size: 22px !important;
line-height: 1.5 !important;
height: 70px !important;
width: 200px !important;
}
""",
) as demo:
    gr.Markdown("## Image Captioning + Cross-Attention Visualization")

    with gr.Row():
        input_img = gr.Image(
            type="pil",
            label="Upload Image",
            value="demo/36384.jpg",
            scale=0,
            image_mode="RGB",
            interactive=True,
        )
        with gr.Column():
            max_tokens = gr.Slider(1, 64, value=32, step=1, label="Max Tokens")
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay Transparency")
            run_btn = gr.Button("Generate Caption + Heatmaps", variant="primary")
            caption_out = gr.Textbox(label="Generated Caption")
            mode = gr.Radio(
                choices=["Averaged", "All Heads"],
                value="Averaged",
                label="Attention Heads",
            )
            # Hidden until a rollout exists; its maximum is resized after each run.
            step_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Token",
                visible=False,
                interactive=True,
            )
            with gr.Row():
                # Labels were mojibake ("â—€", "â–¶"): UTF-8 arrows mis-decoded as cp1252.
                prev_btn = gr.Button("◀ Prev")
                next_btn = gr.Button("Next ▶")

    gr.Markdown("## Cross-Attention Heatmap")
    with gr.Row():
        attention_img = gr.Image(
            label="Averaged Attention Overlay",
            visible=True,
            container=False,
            scale=1,
        )
        attention_label = gr.Textbox(
            label="Token",
            interactive=False,
            elem_classes=["token-box"],
            scale=1,
        )
    head_gallery = gr.Gallery(
        label="All Heads",
        visible=False,
        columns=6,
        height="auto",
    )

    # Per-session results of the last rollout.
    avg_state = gr.State()
    heads_state = gr.State()
    labels_state = gr.State()
    max_step_state = gr.State()

    # Shared component lists so the click and page-load pipelines cannot drift apart.
    rollout_inputs = [input_img, max_tokens, alpha]
    rollout_outputs = [
        caption_out,
        attention_img,
        avg_state,
        heads_state,
        labels_state,
        max_step_state,
    ]
    display_inputs = [step_slider, mode, avg_state, heads_state, labels_state]
    display_outputs = [attention_img, attention_label, head_gallery]

    def _reset_step_slider(max_step):
        """Show the token slider, sized to the new rollout, positioned on the first token."""
        return gr.update(visible=True, maximum=max_step, value=0)

    def _wire_rollout(trigger):
        """Attach rollout -> slider reset -> display refresh to an event trigger."""
        (
            trigger(fn=run_full_rollout, inputs=rollout_inputs, outputs=rollout_outputs)
            .then(_reset_step_slider, inputs=max_step_state, outputs=step_slider)
            # Refresh explicitly: if the slider was already at 0, resetting it may not
            # fire step_slider.change, which would leave a stale token label / gallery.
            .then(update_display, inputs=display_inputs, outputs=display_outputs)
        )

    # Run Rollout: on button click and once when the page loads.
    _wire_rollout(run_btn.click)
    _wire_rollout(demo.load)

    # Updates on Step Change
    step_slider.change(fn=update_display, inputs=display_inputs, outputs=display_outputs)
    prev_btn.click(step_prev, inputs=step_slider, outputs=step_slider)
    next_btn.click(step_next, inputs=[step_slider, max_step_state], outputs=step_slider)

    # Updates on Mode Change
    mode.change(fn=update_display, inputs=display_inputs, outputs=display_outputs)

demo.launch(share=True)