File size: 7,903 Bytes
5fa1af0
57c8491
2ec5753
57c8491
115893a
e5e6c04
 
 
1323bb7
 
 
 
e5e6c04
2ec5753
e5e6c04
 
 
 
 
2ec5753
57c8491
41c94d8
57c8491
2ec5753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57c8491
2ec5753
 
 
57c8491
2ec5753
 
 
 
 
 
 
 
e5e6c04
 
41c94d8
5fa1af0
e5e6c04
 
55226e5
 
 
 
 
41c94d8
 
 
55226e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fa1af0
41c94d8
 
 
5fa1af0
41c94d8
 
 
55226e5
41c94d8
 
e5e6c04
 
 
41c94d8
 
55226e5
 
41c94d8
e5e6c04
55226e5
 
 
 
 
 
 
 
 
e5e6c04
 
 
41c94d8
e5e6c04
 
 
2ec5753
41c94d8
 
 
2ec5753
41c94d8
2ec5753
41c94d8
 
 
 
 
 
 
1323bb7
 
41c94d8
1323bb7
5fa1af0
 
 
 
41c94d8
 
 
 
 
 
 
2ec5753
e5e6c04
5fa1af0
 
41c94d8
 
5fa1af0
 
 
 
1323bb7
41c94d8
 
e5e6c04
 
5fa1af0
41c94d8
 
 
 
 
 
 
57c8491
1323bb7
41c94d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ec5753
41c94d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os

# huggingface_hub (imported by both gradio and transformers) reads
# HF_HUB_OFFLINE once, at import time.  The override therefore has to be in
# place *before* those packages are imported; assigning it afterwards leaves
# an inherited offline setting in effect.
os.environ["HF_HUB_OFFLINE"] = "0"

import matplotlib.pyplot as plt  # noqa: E402
import matplotlib.cm as cm  # noqa: E402
import numpy as np  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoModel, AutoImageProcessor  # noqa: E402
from PIL import Image  # noqa: E402
import torch  # noqa: E402

# Global state shared by the Gradio callbacks.  Every entry stays None until
# load_model() succeeds and fills them in.
state = {
    "model_type": None,  # set to "custom" by load_model()
    "model": None,       # the loaded model, already moved to its device
    "processor": None,   # image processor loaded from the same repo
    "repo_id": None,     # Hub repo ID the model was loaded from
}

def similarity_heatmap(image):
    """
    Compute cosine similarity between CLS token and patch tokens.

    The model and image processor are read from the module-level ``state``
    dict.  The image is preprocessed, pushed through the model once without
    gradients, and the first token of ``last_hidden_state`` (the CLS token)
    is compared against the tokens that belong to image patches.

    Args:
        image: Image in any form the loaded image processor accepts (the UI
            passes a PIL image).

    Returns:
        A float32 NumPy array of shape ``(H_patch, W_patch)``, where the grid
        size is the preprocessed image height/width divided by the model's
        patch size.

    Raises:
        RuntimeError: If no model/processor is stored in ``state``.
        ValueError: If the model returns too few tokens to fill the grid.
    """
    model, processor = state["model"], state["processor"]
    if model is None or processor is None:
        raise RuntimeError("No model loaded; call load_model() first")

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(model.device)  # shape: (1, 3, H, W)

    # Patch size comes from the model config; some configs store it as a
    # (height, width) pair rather than a single int.
    patch_size = model.config.patch_size
    if isinstance(patch_size, (tuple, list)):
        patch_h, patch_w = patch_size
    else:
        patch_h = patch_w = patch_size

    # Patch grid of the *preprocessed* image (needed to reshape the result).
    H_patch = pixel_values.shape[2] // patch_h
    W_patch = pixel_values.shape[3] // patch_w
    num_patches = H_patch * W_patch

    with torch.no_grad():
        outputs = model(pixel_values)
    last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden_dim)

    seq_len = last_hidden_state.shape[1]
    if seq_len <= num_patches:
        raise ValueError(
            f"Expected more than {num_patches} tokens (CLS + patches), got {seq_len}"
        )

    cls_token = last_hidden_state[:, 0, :]  # shape: (1, hidden_dim)
    # Take the trailing num_patches tokens instead of everything after index 0.
    # For a plain ViT (1 + num_patches tokens) this is identical to [:, 1:, :],
    # but it also works when extra special tokens sit between CLS and the
    # patches (e.g. the distillation token of distilled DeiT).
    # NOTE(review): assumes any extra tokens precede the patch tokens.
    patch_tokens = last_hidden_state[:, -num_patches:, :]  # (1, num_patches, hidden_dim)

    # cosine_similarity clamps the norms with a small eps, so an all-zero
    # token yields 0 instead of NaN.
    cos_sim = torch.nn.functional.cosine_similarity(
        patch_tokens, cls_token.unsqueeze(1), dim=-1
    )  # shape: (1, num_patches)

    # Move to CPU (NumPy cannot read CUDA tensors) and upcast half-precision
    # outputs before converting.
    grid = cos_sim[0].reshape(H_patch, W_patch)
    return grid.detach().float().cpu().numpy()

def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha=0.5, colormap="viridis"):
    """
    Render a patch-level similarity grid as a heatmap blended over an image.

    Args:
        cos_grid: (H_patch, W_patch) numpy array of cosine similarities.
        image: PIL.Image the heatmap is stretched over.
        alpha: Blending factor passed to ``Image.blend``; 0 shows only the
            image, 1 shows only the heatmap.
        colormap: Matplotlib colormap name.

    Returns:
        An RGBA ``PIL.Image`` with the same size as ``image``.
    """
    cos_grid = np.asarray(cos_grid, dtype=float)

    # Normalize cosine values to [0, 1] for the colormap; the epsilon keeps a
    # constant grid from dividing by zero.
    lowest = cos_grid.min()
    span = cos_grid.max() - lowest + 1e-8
    norm_grid = (cos_grid - lowest) / span

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # available in both old and new releases.
    cmap = plt.get_cmap(colormap)
    heatmap_rgba = cmap(norm_grid)  # shape: (H_patch, W_patch, 4), floats in [0, 1]

    # Drop the colormap's alpha channel and convert to 8-bit RGB.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb)

    # Upsample the coarse patch grid to the original image size.
    heatmap_resized = heatmap_img.resize(image.size, resample=Image.BILINEAR)

    # Image.blend needs both operands in the same mode and size.
    return Image.blend(
        image.convert("RGBA"),
        heatmap_resized.convert("RGBA"),
        alpha=alpha,
    )

def _load_model_and_processor(repo_id, revision, **hub_kwargs):
    """Load the model and its image processor from the Hub with identical options."""
    # SECURITY NOTE(review): trust_remote_code=True lets the repository run its
    # own Python code during loading, and repo_id is whatever the caller typed.
    # Kept for compatibility with custom architectures; restrict or disable it
    # if this is ever exposed to untrusted users.
    shared = dict(
        revision=revision,
        trust_remote_code=True,
        use_auth_token=False,  # Explicitly no auth for public models
        **hub_kwargs,
    )
    model = AutoModel.from_pretrained(repo_id, **shared)
    processor = AutoImageProcessor.from_pretrained(repo_id, **shared)
    return model, processor


def load_model(repo_id: str, revision: str = None):
    """
    Load a Hugging Face model + processor from Hub.
    Works with any public repo_id.

    On success the model is moved to CUDA when available (CPU otherwise), put
    in eval mode, and stored together with its processor in the module-level
    ``state`` dict.  ``state`` is left untouched on any failure.

    Args:
        repo_id: Hub repository ID, e.g. ``"google/vit-base-patch16-224"``.
            Surrounding whitespace is ignored.
        revision: Optional branch, tag, or commit hash.  ``None``, an empty
            string, or whitespace selects the repository default.

    Returns:
        A human-readable status string describing success or the failure
        reason.  Exceptions are never propagated to the caller.
    """
    try:
        # Clean up inputs (tolerate None as well as blank strings).
        repo_id = (repo_id or "").strip()
        if not repo_id:
            return "Please enter a model repo ID"

        # A blank revision (e.g. an empty UI field) must become None; passing
        # "" through would ask the Hub for a revision literally named "".
        revision = (revision or "").strip() or None

        # First try without cache_dir to avoid permission issues
        try:
            model, processor = _load_model_and_processor(repo_id, revision)
        except Exception:
            # If that fails, retry with an explicit cache directory under /tmp
            # (better permissions) and make sure downloading is allowed.
            model, processor = _load_model_and_processor(
                repo_id,
                revision,
                cache_dir="/tmp/model_cache",
                local_files_only=False,
            )

        # Validate it's a Vision Transformer before spending time/memory on
        # moving the weights to the device.
        if not hasattr(model.config, 'patch_size'):
            return f"Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"

        # Move to appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

        # Update global state
        state["model"] = model
        state["processor"] = processor
        state["repo_id"] = repo_id
        state["model_type"] = "custom"

        patch_size = model.config.patch_size
        return f"Successfully loaded ViT model '{repo_id}' (patch size: {patch_size}) on {device}"

    except Exception as e:
        # Map common failure modes to friendlier messages by inspecting the
        # exception text.
        error_str = str(e).lower()
        if "repository not found" in error_str or "404" in error_str:
            return f"Repository '{repo_id}' not found. Please check the repo ID."
        elif "connection" in error_str or "network" in error_str or "offline" in error_str:
            return f"Network error: {str(e)}"
        elif "permission" in error_str or "forbidden" in error_str:
            return "Permission denied. This might be a private repository."
        else:
            return f"Error loading model: {str(e)}"

def display_image(image: "Image.Image") -> "Image.Image":
    """
    Simply returns the uploaded image.

    Used as a pass-through callback so the uploaded picture can be mirrored
    into a second image component.

    Args:
        image: The uploaded image (whatever value the caller supplies is
            returned as-is, including ``None``).

    Returns:
        The same object that was passed in.
    """
    # The annotation names the PIL image *class* (Image.Image); the bare
    # ``Image`` name in this file is the PIL module.
    return image

def visualize_cosine_heatmap(image: "Image.Image"):
    """
    Generate and overlay cosine similarity heatmap on the input image.

    Args:
        image: The image to analyse, or ``None`` when nothing was uploaded.

    Returns:
        The blended heatmap image, or ``None`` when there is no image, no
        model has been loaded, or heatmap generation fails (the error is
        printed, not raised).
    """
    # Nothing to do without an image; bail out before touching the model
    # instead of letting the processor fail and log a misleading error.
    if image is None:
        return None

    if state["model"] is None:
        return None  # Return None if no model is loaded

    try:
        cos_grid = similarity_heatmap(image)
        return overlay_cosine_grid_on_image(cos_grid, image)
    except Exception as e:
        # Best-effort UI callback: report the problem and show no output
        # rather than propagating the exception.
        print(f"Error generating heatmap: {e}")
        return None

# Gradio UI
# Builds the Blocks app: intro text, a model-loading row with a status box,
# and an image row with the uploaded picture next to its heatmap.
with gr.Blocks(title="ViT CLS Visualizer") as demo:
    # Page heading and short usage instructions.
    gr.Markdown("# ViT CLS-Visualizer")
    gr.Markdown(
        "Enter the Hugging Face model repo ID (must be public), upload an image, "
        "and visualize the cosine similarity between the CLS token and patches."
    )
    
    # Example repo IDs the user can paste into the repo field.
    gr.Markdown("### Popular Vision Transformer models to try:")
    gr.Markdown(
        "- `google/vit-base-patch16-224`\n"
        "- `facebook/deit-base-distilled-patch16-224`\n"
        "- `microsoft/dit-base`"
    )

    # Model selection row: repo ID (pre-filled), optional revision, load button.
    with gr.Row():
        repo_input = gr.Textbox(
            label="Hugging Face Model Repo ID",
            placeholder="e.g. google/vit-base-patch16-224",
            value="google/vit-base-patch16-224"
        )
        revision_input = gr.Textbox(
            label="Revision (optional)",
            placeholder="branch, tag, or commit hash"
        )
        load_btn = gr.Button("Load Model", variant="primary")
    
    # Read-only box that shows the status string returned by load_model.
    load_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Row():
        # Left column: upload widget plus a copy of the uploaded image.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output = gr.Image(label="Uploaded Image")
        
        # Right column: trigger button and the resulting heatmap overlay.
        with gr.Column():
            compute_btn = gr.Button("Compute Heatmap", variant="primary")
            heatmap_output = gr.Image(label="Cosine Similarity Heatmap")

    # Events
    # "Load Model" passes the repo ID and revision to load_model and shows
    # its returned message in the status box.
    load_btn.click(
        fn=load_model, 
        inputs=[repo_input, revision_input], 
        outputs=load_status
    )
    
    # Whenever the uploaded image changes, mirror it into image_output.
    image_input.change(
        fn=display_image, 
        inputs=image_input, 
        outputs=image_output
    )
    
    # "Compute Heatmap" runs the uploaded image through
    # visualize_cosine_heatmap and displays the result.
    compute_btn.click(
        fn=visualize_cosine_heatmap, 
        inputs=image_input, 
        outputs=heatmap_output
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()