Update app.py
Browse files
app.py
CHANGED
@@ -1,198 +1,160 @@
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import plotly.express as px
|
4 |
-
import numpy as np
|
5 |
from transformers import AutoModel, AutoTokenizer
|
6 |
|
7 |
########################################
|
8 |
-
#
|
9 |
########################################
|
10 |
model_name = "distilbert-base-uncased"
|
11 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
12 |
model = AutoModel.from_pretrained(model_name, output_attentions=True)
|
13 |
model.eval()
|
14 |
|
15 |
-
|
16 |
-
# 2) Generate attention analysis
|
17 |
-
########################################
|
18 |
-
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
|
19 |
"""
|
20 |
-
1. Tokenize
|
21 |
-
2.
|
22 |
-
3.
|
23 |
-
4.
|
24 |
-
5.
|
25 |
-
6. Create text summary of top-k focuses for each token.
|
26 |
-
7. Generate an "interpretation" to highlight interesting patterns.
|
27 |
"""
|
28 |
-
|
29 |
with torch.no_grad():
|
30 |
inputs = tokenizer.encode_plus(text, return_tensors="pt")
|
31 |
outputs = model(**inputs)
|
32 |
-
all_attentions = outputs.attentions
|
33 |
-
# DistilBERT has 6 layers => valid
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
-
|
37 |
input_ids = inputs["input_ids"][0]
|
38 |
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
39 |
-
seq_len = len(tokens)
|
40 |
-
|
41 |
-
# (Optional) Heatmap
|
42 |
-
fig_dict = None
|
43 |
-
if show_heatmap:
|
44 |
-
fig = px.imshow(
|
45 |
-
att_matrix,
|
46 |
-
x=tokens,
|
47 |
-
y=tokens,
|
48 |
-
labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
|
49 |
-
color_continuous_scale="Blues",
|
50 |
-
title=f"DistilBERT Self-Attention (Layer {layer})"
|
51 |
-
)
|
52 |
-
fig.update_xaxes(side="top")
|
53 |
-
fig.update_traces(
|
54 |
-
hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
|
55 |
-
)
|
56 |
-
fig_dict = fig.to_dict()
|
57 |
-
|
58 |
-
# Top-K Summary for each row
|
59 |
-
summary_md = "## Top-K Focus for Each Token\n"
|
60 |
-
summary_md += f"Showing the **top {top_k}** tokens each token focuses on.\n\n"
|
61 |
-
for i in range(seq_len):
|
62 |
-
row_token = tokens[i]
|
63 |
-
row_weights = att_matrix[i]
|
64 |
-
sorted_idx = row_weights.argsort()[::-1]
|
65 |
-
top_indices = sorted_idx[:top_k]
|
66 |
-
|
67 |
-
summary_md += f"**Token '{row_token}'** focuses on:\n"
|
68 |
-
for j in top_indices:
|
69 |
-
col_token = tokens[j]
|
70 |
-
weight = row_weights[j]
|
71 |
-
summary_md += f" - `{col_token}` (weight={weight:.3f})\n"
|
72 |
-
summary_md += "\n"
|
73 |
|
74 |
-
#
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
|
|
|
82 |
|
83 |
-
|
84 |
-
# 3) Interpretation function
|
85 |
-
########################################
|
86 |
-
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
|
87 |
"""
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
- Possibly mention if we see something interesting about distribution.
|
92 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
i, j = max_rc
|
117 |
-
token_i = tokens[i]
|
118 |
-
token_j = tokens[j]
|
119 |
-
global_msg = f"- The **highest single focus** in the matrix is **{max_val:.3f}**, from token '{token_i}' onto '{token_j}'."
|
120 |
-
|
121 |
-
# 3) Possibly some quick ratio
|
122 |
-
# For each row, sum of row vs. sum of diagonal
|
123 |
-
# We'll keep it simpler for now
|
124 |
-
|
125 |
-
interpretation = "## Additional Interpretation\n\n"
|
126 |
-
interpretation += (
|
127 |
-
"Here are some overall patterns in the attention matrix that might help you:\n\n"
|
128 |
)
|
129 |
-
interpretation += f"{diag_msg}\n"
|
130 |
-
interpretation += f"{global_msg}\n"
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
)
|
136 |
|
137 |
-
return
|
138 |
|
|
|
|
|
|
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
description_md = """
|
144 |
-
# DistilBERT Self-Attention with Extra Interpretation
|
145 |
-
|
146 |
-
**Instructions:**
|
147 |
-
1. Type your text into the box.
|
148 |
-
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
|
149 |
-
3. Decide how many top tokens you want listed for each token (Top-K).
|
150 |
-
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.
|
151 |
-
|
152 |
-
**Reading the Heatmap**:
|
153 |
-
- **Rows** = tokens doing the looking (focus).
|
154 |
-
- **Columns** = tokens being looked at.
|
155 |
-
- **Color intensity** = how strongly the row token focuses on the column token.
|
156 |
-
|
157 |
-
Below the heatmap, you'll see:
|
158 |
-
- A **Top-K focus** summary for each token.
|
159 |
-
- An **interpretation** bullet list that highlights interesting overall patterns.
|
160 |
-
"""
|
161 |
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
|
|
|
|
|
|
|
166 |
with gr.Blocks() as demo:
|
167 |
-
gr.Markdown(
|
|
|
168 |
|
169 |
with gr.Row():
|
170 |
-
|
171 |
-
label="Enter
|
172 |
value="Transformers handle long-range context in parallel."
|
173 |
)
|
174 |
-
|
175 |
minimum=0, maximum=5, step=1, value=5,
|
176 |
-
label="DistilBERT Layer"
|
177 |
-
)
|
178 |
-
topk_in = gr.Slider(
|
179 |
-
minimum=1, maximum=6, step=1, value=3,
|
180 |
-
label="Top-K Focus"
|
181 |
)
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
185 |
)
|
186 |
-
run_btn = gr.Button("Analyze Attention")
|
187 |
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
|
|
195 |
)
|
196 |
|
197 |
demo.launch()
|
198 |
-
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import plotly.express as px
|
|
|
4 |
from transformers import AutoModel, AutoTokenizer
|
5 |
|
6 |
########################################
# Load Transformer (DistilBERT) with attention
########################################
model_name = "distilbert-base-uncased"

# Tokenizer and encoder are loaded once at import time and shared by every
# request handler below.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# output_attentions=True makes each forward pass expose `outputs.attentions`.
# nn.Module.eval() returns the module itself, so the call can be chained.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()
|
13 |
|
14 |
+
def visualize_attention(text, layer=5):
    """Build a Plotly heatmap of head-averaged self-attention for one layer.

    Steps:
        1. Tokenize ``text`` (special tokens are kept and shown on the axes).
        2. Run a forward pass of the module-level ``model`` to get the
           per-layer attention tensors.
        3. Select ``layer`` and average over the attention-head dimension.
        4. Render the resulting (seq_len x seq_len) matrix with Plotly,
           rows labelled as Query tokens and columns as Key tokens.

    Args:
        text: Sentence to analyse.
        layer: Index into ``outputs.attentions``. Coerced with ``int()``
            because UI widgets may deliver the value as a float, and a float
            cannot index the attentions tuple.

    Returns:
        The Plotly figure produced by ``px.imshow``.

    Note:
        ``layer`` is not range-checked here; an index outside the available
        layers propagates the indexing error to the caller.
    """
    layer = int(layer)

    with torch.no_grad():
        # Calling the tokenizer directly replaces the deprecated encode_plus().
        # truncation=True keeps over-long input within the model's maximum
        # length instead of failing inside the forward pass.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)

    all_attentions = outputs.attentions
    # DistilBERT has 6 layers => valid layer indices: 0..5
    attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Convert to numpy for plotting (batch size is 1, so take element 0).
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens) to label both axes.
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})",
    )
    # Put the Key labels on top so the plot reads like a matrix.
    fig.update_xaxes(side="top")

    # Add tooltip
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))

    return fig
|
54 |
|
55 |
+
def interpret_token_attention(text, token_index=0, layer=5):
    """Explain in Markdown which tokens a chosen Query token attends to most.

    The attention of ``layer`` is averaged over heads, the row belonging to
    ``token_index`` is extracted, and the (up to) three Key tokens with the
    largest weights are listed.

    Args:
        text: Sentence to analyse.
        token_index: 0-based position of the Query token in the tokenized
            sequence (special tokens count). Numeric widgets may deliver a
            float or ``None``; the value is coerced to ``int`` and anything
            that cannot be converted is reported as an invalid index.
        layer: Index into ``outputs.attentions``; coerced with ``int()``.

    Returns:
        A Markdown string: either the explanation, or an error message when
        ``token_index`` is missing, not a number, or out of range.
    """
    invalid_msg = "Invalid token index. Please choose a valid token index."

    # A float such as 2.0 cannot index a list or tensor, and None cannot be
    # compared with an int, so normalise before using the value.
    try:
        token_index = int(token_index)
    except (TypeError, ValueError, OverflowError):
        return invalid_msg
    layer = int(layer)

    # Calling the tokenizer directly replaces the deprecated encode_plus().
    # truncation=True keeps over-long input within the model's maximum length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    # Get tokens
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Safety check for token_index -- done before the forward pass so that an
    # invalid request does not pay for running the model.
    if token_index < 0 or token_index >= len(tokens):
        return invalid_msg

    with torch.no_grad():
        outputs = model(**inputs)
    all_attentions = outputs.attentions
    attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Extract the row corresponding to our Query token
    query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)

    # Sort tokens by attention weight (descending) and keep the top 3
    # (fewer when the sequence has fewer than three tokens).
    top_indices = query_attn.argsort()[::-1][:3]

    # Build an explanation
    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = What this token is looking for\n"
        "- **Key** = What another token has to offer\n"
        "- **Value** = The actual information from that token\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )

    explanation += "".join(
        f"- **{tokens[i]}** with attention weight ~ {query_attn[i]:.3f}\n"
        for i in top_indices
    )

    explanation += (
        "\nA higher attention weight indicates that this Query token is 'looking at' or "
        "focusing on that Key token more strongly, likely because it finds the Key token "
        "relevant to its meaning or context."
    )

    return explanation
|
105 |
|
106 |
+
# Markdown shown at the top of the page: explains how to read the heatmap
# (rows are Query tokens, columns are Key tokens).
description_text = (
    "\n"
    "## Understanding Transformer Self-Attention\n"
    "\n"
    "- **Rows = Query token** (the token doing the 'looking').\n"
    "- **Columns = Key token** (the token being 'looked at').\n"
    "- Darker color = stronger attention weight.\n"
    "\n"
    "**Transformers** process all tokens in **parallel**, allowing any token to attend to any other token in the sentence.\n"
    "This makes it easier for the model to capture long-distance relationships.\n"
)
|
117 |
|
118 |
+
########################################
# Gradio Interface
########################################
# NOTE(review): the nesting below is reconstructed; the sentence box and the
# layer slider are assumed to share one row -- confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )
    output_plot = gr.Plot(label="Attention Heatmap")

    # Visualization button: renders the heatmap for the chosen layer.
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

    # Numeric field to choose a token index for interpretation.
    # precision=0 makes Gradio round the value and pass it to the callback as
    # an int, which is what list/tensor indexing requires.
    token_index = gr.Number(
        label="Choose a token index to interpret (0-based)",
        value=0,
        precision=0
    )

    interpretation_output = gr.Markdown(label="Interpretation")

    # Interpretation button: explains the attention row of the chosen token.
    interpret_button = gr.Button("Explain This Token's Attention")
    interpret_button.click(
        fn=interpret_token_attention,
        inputs=[text_input, token_index, layer_slider],
        outputs=interpretation_output
    )

demo.launch()
|
|