Update app.py
app.py
CHANGED
@@ -1,91 +1,95 @@
import gradio as gr
-generator = load_model()
-def generate_text(prompt, max_length=100, temperature=1.0, top_p=0.9):
    """
-        top_p (float): Nucleus sampling hyperparameter.
-    Returns:
-        str: Generated text from GPT-2.
    """
    )
-    """
-    # Educational GPT-2 Demo
-    This demo demonstrates how a smaller Large Language Model (GPT-2) predicts text.
-    Change the parameters below to see how the model's output is affected:
-    - **Max Length** controls the total number of tokens in the output.
-    - **Temperature** controls randomness (higher means more creative/chaotic).
-    - **Top-p** controls the diversity of tokens (lower means more conservative choices).
-    """
    )
    with gr.Row():
-        )
-            minimum=20,
-            maximum=200,
-            value=100,
-            step=1,
-            label="Max Length"
-        )
-        temp = gr.Slider(
-            minimum=0.1,
-            maximum=2.0,
-            value=1.0,
-            step=0.1,
-            label="Temperature"
-        )
-        top_p = gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.9,
-            step=0.05,
-            label="Top-p"
-        )
-        generate_button = gr.Button("Generate")
-    output_box = gr.Textbox(
-        label="Generated Text",
-        lines=10
-    )
    )
demo.launch()
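The removed Markdown above describes how **Max Length**, **Temperature**, and **Top-p** shape GPT-2's output, but the bodies of `load_model()` and `generate_text()` are not visible in this hunk. As a rough reference only, here is a minimal sketch of how those three parameters typically feed a `transformers` text-generation pipeline; the function bodies below are assumptions for illustration, not code recovered from the removed file:

```python
# Illustrative sketch (assumption, not the removed implementation):
# how max_length, temperature, and top_p typically plug into GPT-2 generation.
from transformers import pipeline

def load_model():
    # "gpt2" is the small 124M-parameter checkpoint on the Hugging Face Hub
    return pipeline("text-generation", model="gpt2")

generator = load_model()

def generate_text(prompt, max_length=100, temperature=1.0, top_p=0.9):
    result = generator(
        prompt,
        max_length=max_length,    # upper bound on total tokens (prompt + continuation)
        do_sample=True,           # sampling must be on for temperature/top_p to take effect
        temperature=temperature,  # higher => flatter distribution => more random output
        top_p=top_p,              # nucleus sampling: sample only from the top cumulative-probability mass
    )
    return result[0]["generated_text"]
```

The added version of `app.py` replaces this GPT-2 demo with a DistilBERT attention visualizer: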
+import torch
import gradio as gr
+import plotly.express as px
+from transformers import AutoModel, AutoTokenizer

+########################################
+# Load Transformer (DistilBERT) with attention
+########################################
+model_name = "distilbert-base-uncased"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+# Note: output_attentions=True to extract attention matrices
+model = AutoModel.from_pretrained(model_name, output_attentions=True)
+model.eval()

+def visualize_attention(text, layer=5):
    """
+    1. Tokenize input text.
+    2. Run DistilBERT forward pass to get attention matrices.
+    3. Pick a layer (0..5) and average across attention heads.
+    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
+    5. Label axes with tokens (Query vs. Key).
    """
+    with torch.no_grad():
+        inputs = tokenizer.encode_plus(text, return_tensors="pt")
+        outputs = model(**inputs)
+        # outputs.attentions: tuple of shape [num_layers] each => (batch=1, num_heads, seq_len, seq_len)
+        all_attentions = outputs.attentions
+        # DistilBERT has 6 layers => valid indices: 0..5
+        attn_layer = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)
+
+    # Convert to numpy for plotting
+    attn_matrix = attn_layer[0].cpu().numpy()
+
+    # Get tokens (including special tokens like [CLS], [SEP])
+    input_ids = inputs["input_ids"][0]
+    tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+    # Build a Plotly heatmap
+    fig = px.imshow(
+        attn_matrix,
+        x=tokens,
+        y=tokens,
+        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
+        color_continuous_scale="Blues",
+        title=f"DistilBERT Attention (Layer {layer})"
    )
+    fig.update_xaxes(side="top")

+    # Add tooltip: shows row token, column token, and attention weight
+    fig.update_traces(
+        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
+    return fig
+
+# Short explanation text for the UI
+description_text = """
+## Understanding Transformer Self-Attention
+
+- **Rows = "Query token"** (the token that is looking at other tokens)
+- **Columns = "Key token"** (the token being looked at)
+- Darker (or higher) color = stronger attention.
+
+**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
+Thus, **long-distance dependencies** are easier to capture: any token can directly
+attend to any other token, regardless of distance in the sentence.
+"""
+
+########################################
+# Gradio Interface
+########################################
+with gr.Blocks() as demo:
+    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
+    gr.Markdown(description_text)

    with gr.Row():
+        text_input = gr.Textbox(
+            label="Enter a sentence",
+            value="Transformers handle long-range context in parallel."
+        )
+        layer_slider = gr.Slider(
+            minimum=0, maximum=5, step=1, value=5,
+            label="DistilBERT Layer (0=lowest, 5=highest)"
+        )

+    output_plot = gr.Plot(label="Attention Heatmap")

+    visualize_button = gr.Button("Visualize Attention")
+    visualize_button.click(
+        fn=visualize_attention,
+        inputs=[text_input, layer_slider],
+        outputs=output_plot
    )

demo.launch()
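As the `description_text` block explains, each row of the heatmap is a query token's attention distribution over the key tokens (columns). A quick sanity check of the shapes and row sums that `visualize_attention` relies on, using the same `distilbert-base-uncased` checkpoint (this snippet is illustrative and not part of the commit):

```python
# Illustrative check (not part of the commit): inspect the attention tensors
# that visualize_attention averages and plots.
import torch
from transformers import AutoModel, AutoTokenizer

name = "distilbert-base-uncased"
tok = AutoTokenizer.from_pretrained(name)
mdl = AutoModel.from_pretrained(name, output_attentions=True).eval()

with torch.no_grad():
    enc = tok("Transformers handle long-range context in parallel.", return_tensors="pt")
    out = mdl(**enc)

print(len(out.attentions))           # 6 layers in DistilBERT => slider range 0..5
print(out.attentions[5].shape)       # (1, 12, seq_len, seq_len): batch, heads, query, key
avg = out.attentions[5].mean(dim=1)  # head-averaged matrix, i.e. what the heatmap shows
print(avg[0].sum(dim=-1))            # each query row sums to ~1.0 (softmax over keys)
```

Running `python app.py` launches the Gradio app; the layer slider simply selects which of the six head-averaged matrices gets plotted.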