kevin1911 committed on
Commit e93333c · verified · 1 Parent(s): b7556c2

Update app.py

Files changed (1)
  1. app.py +112 -150
app.py CHANGED
@@ -1,198 +1,160 @@
  import torch
  import gradio as gr
  import plotly.express as px
- import numpy as np
  from transformers import AutoModel, AutoTokenizer

  ########################################
- # 1) Load DistilBERT with attention
  ########################################
  model_name = "distilbert-base-uncased"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModel.from_pretrained(model_name, output_attentions=True)
  model.eval()

- ########################################
- # 2) Generate attention analysis
- ########################################
- def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
      """
-     1. Tokenize 'text'.
-     2. Forward pass DistilBERT (output_attentions=True).
-     3. Extract attention from chosen layer (0..5).
-     4. Average across heads => (seq_len, seq_len).
-     5. Optionally create Plotly heatmap (fig_dict).
-     6. Create text summary of top-k focuses for each token.
-     7. Generate an "interpretation" to highlight interesting patterns.
      """
-
      with torch.no_grad():
          inputs = tokenizer.encode_plus(text, return_tensors="pt")
          outputs = model(**inputs)
-         all_attentions = outputs.attentions  # tuple: (#layers) each => (1, #heads, seq_len, seq_len)
-         # DistilBERT has 6 layers => valid range: 0..5
-         att = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)

-     att_matrix = att[0].cpu().numpy()  # (seq_len, seq_len)
      input_ids = inputs["input_ids"][0]
      tokens = tokenizer.convert_ids_to_tokens(input_ids)
-     seq_len = len(tokens)
-
-     # (Optional) Heatmap
-     fig_dict = None
-     if show_heatmap:
-         fig = px.imshow(
-             att_matrix,
-             x=tokens,
-             y=tokens,
-             labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
-             color_continuous_scale="Blues",
-             title=f"DistilBERT Self-Attention (Layer {layer})"
-         )
-         fig.update_xaxes(side="top")
-         fig.update_traces(
-             hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
-         )
-         fig_dict = fig.to_dict()
-
-     # Top-K Summary for each row
-     summary_md = "## Top-K Focus for Each Token\n"
-     summary_md += f"Showing the **top {top_k}** tokens each token focuses on.\n\n"
-     for i in range(seq_len):
-         row_token = tokens[i]
-         row_weights = att_matrix[i]
-         sorted_idx = row_weights.argsort()[::-1]
-         top_indices = sorted_idx[:top_k]
-
-         summary_md += f"**Token '{row_token}'** focuses on:\n"
-         for j in top_indices:
-             col_token = tokens[j]
-             weight = row_weights[j]
-             summary_md += f"  - `{col_token}` (weight={weight:.3f})\n"
-         summary_md += "\n"

-     # Generate an additional "interpretation" to highlight patterns
-     interpretation_md = interpret_attention(att_matrix, tokens)
-
-     # Combine summaries
-     combined_md = summary_md + "\n" + interpretation_md

-     return fig_dict, combined_md


- ########################################
- # 3) Interpretation function
- ########################################
- def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
      """
-     Provide a short bullet-list interpretation of the attention matrix:
-     - Count how many tokens mostly attend to themselves (diagonal).
-     - Find the global max attention weight (which row->col?), mention tokens involved.
-     - Possibly mention if we see something interesting about distribution.
      """

-     seq_len = len(tokens)
-     diagonal_focus_count = 0
-     # We'll track the max weight overall
-     max_val = -1.0
-     max_rc = (0, 0)
-
-     # For each row, check if diagonal is the top focus
-     for i in range(seq_len):
-         row = att_matrix[i]
-         best_j = row.argmax()
-         if best_j == i:
-             diagonal_focus_count += 1
-         # Check global max
-         if row[best_j] > max_val:
-             max_val = row[best_j]
-             max_rc = (i, best_j)
-
-     # Summaries
-     # 1) Diagonal focus stat
-     diag_msg = f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."
-
-     # 2) Global max
-     i, j = max_rc
-     token_i = tokens[i]
-     token_j = tokens[j]
-     global_msg = f"- The **highest single focus** in the matrix is **{max_val:.3f}**, from token '{token_i}' onto '{token_j}'."
-
-     # 3) Possibly some quick ratio
-     # For each row, sum of row vs. sum of diagonal
-     # We'll keep it simpler for now
-
-     interpretation = "## Additional Interpretation\n\n"
-     interpretation += (
-         "Here are some overall patterns in the attention matrix that might help you:\n\n"
      )
-     interpretation += f"{diag_msg}\n"
-     interpretation += f"{global_msg}\n"

-     interpretation += "\n- A strong diagonal means tokens often reference themselves.\n"
-     interpretation += (
-         "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n"
      )

-     return interpretation


- ########################################
- # 4) Gradio UI
- ########################################
- description_md = """
- # DistilBERT Self-Attention with Extra Interpretation
-
- **Instructions:**
- 1. Type your text into the box.
- 2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
- 3. Decide how many top tokens you want listed for each token (Top-K).
- 4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.
-
- **Reading the Heatmap**:
- - **Rows** = tokens doing the looking (focus).
- - **Columns** = tokens being looked at.
- - **Color intensity** = how strongly the row token focuses on the column token.
-
- Below the heatmap, you'll see:
- - A **Top-K focus** summary for each token.
- - An **interpretation** bullet list that highlights interesting overall patterns.
- """

- def run_demo(text, layer, top_k, show_heatmap):
-     fig_dict, summary_md = analyze_attention(text, layer, top_k, show_heatmap)
-     return fig_dict, summary_md

  with gr.Blocks() as demo:
-     gr.Markdown(description_md)

      with gr.Row():
-         text_in = gr.Textbox(
-             label="Enter text",
              value="Transformers handle long-range context in parallel."
          )
-         layer_in = gr.Slider(
              minimum=0, maximum=5, step=1, value=5,
-             label="DistilBERT Layer"
-         )
-         topk_in = gr.Slider(
-             minimum=1, maximum=6, step=1, value=3,
-             label="Top-K Focus"
          )
-         show_heatmap_check = gr.Checkbox(
-             label="Show Heatmap?",
-             value=True
          )
-     run_btn = gr.Button("Analyze Attention")

-     out_plot = gr.Plot(label="Attention Heatmap")
-     out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

-     run_btn.click(
-         fn=run_demo,
-         inputs=[text_in, layer_in, topk_in, show_heatmap_check],
-         outputs=[out_plot, out_summary]
      )

  demo.launch()
-
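Both the removed `analyze_attention` and the new `visualize_attention` below hinge on the same step: grabbing `outputs.attentions` and averaging over heads. A minimal sketch of the tensor shapes involved (separate from the commit, using the same model setup as app.py):

```python
# Shape check for DistilBERT attention outputs (sketch, not part of app.py).
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased", output_attentions=True)
model.eval()

with torch.no_grad():
    inputs = tokenizer("Transformers handle long-range context in parallel.",
                       return_tensors="pt")
    outputs = model(**inputs)

# One attention tensor per layer, each (batch, heads, seq_len, seq_len):
print(len(outputs.attentions))                  # 6 layers in distilbert-base-uncased
print(outputs.attentions[5].shape)              # (1, 12, seq_len, seq_len)
print(outputs.attentions[5].mean(dim=1).shape)  # (1, seq_len, seq_len) after head-averaging
```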
 
  import torch
  import gradio as gr
  import plotly.express as px
  from transformers import AutoModel, AutoTokenizer

  ########################################
+ # Load Transformer (DistilBERT) with attention
  ########################################
  model_name = "distilbert-base-uncased"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModel.from_pretrained(model_name, output_attentions=True)
  model.eval()

+ def visualize_attention(text, layer=5):
      """
+     1. Tokenize input text.
+     2. Run DistilBERT forward pass to get attention matrices.
+     3. Pick a layer (0..5) and average across attention heads.
+     4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
+     5. Label axes with tokens (Query vs. Key).
      """
      with torch.no_grad():
          inputs = tokenizer.encode_plus(text, return_tensors="pt")
          outputs = model(**inputs)
+         all_attentions = outputs.attentions
+         # DistilBERT has 6 layers => valid layer indices: 0..5
+         attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)
+
+     # Convert to numpy for plotting
+     attn_matrix = attn_layer[0].cpu().numpy()

+     # Get tokens (including special tokens)
      input_ids = inputs["input_ids"][0]
      tokens = tokenizer.convert_ids_to_tokens(input_ids)

+     # Build a Plotly heatmap
+     fig = px.imshow(
+         attn_matrix,
+         x=tokens,
+         y=tokens,
+         labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
+         color_continuous_scale="Blues",
+         title=f"DistilBERT Attention (Layer {layer})"
+     )
+     fig.update_xaxes(side="top")

+     # Add tooltip
+     fig.update_traces(
+         hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
+     )
+     fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))

+     return fig

+ def interpret_token_attention(text, token_index=0, layer=5):
      """
+     Provides a textual explanation of why a particular token (Query) attends
+     to other tokens in the input, highlighting the top 2 or 3 tokens
+     it focuses on.
      """
+     with torch.no_grad():
+         inputs = tokenizer.encode_plus(text, return_tensors="pt")
+         outputs = model(**inputs)
+         all_attentions = outputs.attentions
+         attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)
+
+     # Get tokens
+     input_ids = inputs["input_ids"][0]
+     tokens = tokenizer.convert_ids_to_tokens(input_ids)

+     # Safety check for token_index (gr.Number passes a float, so cast before indexing)
+     token_index = int(token_index)
+     if token_index < 0 or token_index >= len(tokens):
+         return "Invalid token index. Please choose a valid token index."
+
+     # Extract the row corresponding to our Query token
+     query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)
+
+     # Sort tokens by attention weight (descending)
+     sorted_indices = query_attn.argsort()[::-1]
+     top_indices = sorted_indices[:3]  # Grab top 3
+     top_tokens = [tokens[i] for i in top_indices]
+     top_weights = [query_attn[i] for i in top_indices]
+
+     # Build an explanation
+     query_token_str = tokens[token_index]
+     explanation = (
+         f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
+         "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
+         "- **Query** = What this token is looking for\n"
+         "- **Key** = What another token has to offer\n"
+         "- **Value** = The actual information from that token\n\n"
+         f"As a Query, '{query_token_str}' attends most strongly to:\n"
      )

+     for t, w in zip(top_tokens, top_weights):
+         explanation += f"- **{t}** with attention weight ~ {w:.3f}\n"
+
+     explanation += (
+         "\nA higher attention weight indicates that this Query token is 'looking at' or "
+         "focusing on that Key token more strongly, likely because it finds the Key token "
+         "relevant to its meaning or context."
      )

+     return explanation
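One wrinkle when using this function: `token_index` counts WordPiece tokens, including the `[CLS]` and `[SEP]` markers the tokenizer adds. A quick sketch (not in the commit) to see which index maps to which token, reusing the tokenizer loaded above:

```python
# List token indices for a sentence with the same distilbert-base-uncased tokenizer.
ids = tokenizer.encode_plus("Transformers handle long-range context in parallel.")["input_ids"]
for i, tok in enumerate(tokenizer.convert_ids_to_tokens(ids)):
    print(i, tok)  # index 0 is [CLS]; the last index is [SEP]
```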
 
+ # Short explanation text for the UI
+ description_text = """
+ ## Understanding Transformer Self-Attention
+
+ - **Rows = Query token** (the token doing the 'looking').
+ - **Columns = Key token** (the token being 'looked at').
+ - Darker color = stronger attention weight.
+
+ **Transformers** process all tokens in **parallel**, allowing any token to attend to any other token in the sentence.
+ This makes it easier for the model to capture long-distance relationships.
+ """

+ ########################################
+ # Gradio Interface
+ ########################################
  with gr.Blocks() as demo:
+     gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
+     gr.Markdown(description_text)

      with gr.Row():
+         text_input = gr.Textbox(
+             label="Enter a sentence",
              value="Transformers handle long-range context in parallel."
          )
+         layer_slider = gr.Slider(
              minimum=0, maximum=5, step=1, value=5,
+             label="DistilBERT Layer (0=lowest, 5=highest)"
          )
+     output_plot = gr.Plot(label="Attention Heatmap")
+
+     # Visualization Button
+     visualize_button = gr.Button("Visualize Attention")
+     visualize_button.click(
+         fn=visualize_attention,
+         inputs=[text_input, layer_slider],
+         outputs=output_plot
      )

+     # Number input to choose a token index for interpretation
+     token_index = gr.Number(
+         label="Choose a token index to interpret (0-based)",
+         value=0
+     )
+
+     interpretation_output = gr.Markdown(label="Interpretation")

+     # Interpretation Button
+     interpret_button = gr.Button("Explain This Token's Attention")
+     interpret_button.click(
+         fn=interpret_token_attention,
+         inputs=[text_input, token_index, layer_slider],
+         outputs=interpretation_output
      )

  demo.launch()
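For reference, the two new functions can also be exercised without the Gradio UI; a hypothetical standalone snippet, assuming they are imported from app.py (or that `demo.launch()` is temporarily disabled so importing the module does not block):

```python
# Hypothetical standalone usage of the functions added in this commit.
fig = visualize_attention("Transformers handle long-range context in parallel.", layer=5)
fig.show()  # renders the Plotly heatmap

print(interpret_token_attention(
    "Transformers handle long-range context in parallel.",
    token_index=1,  # index 0 is [CLS]; 1 is the first word piece
    layer=5,
))
```

Note the design change from the old version: `visualize_attention` hands `gr.Plot` the Plotly figure object directly, where the removed `analyze_attention` converted it with `fig.to_dict()` first.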