kevin1911 committed on
Commit 0b23237 · verified · 1 Parent(s): 4e323a6

Update app.py

Files changed (1)
  1. app.py +80 -76
app.py CHANGED
@@ -1,91 +1,95 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
- def load_model(model_name="gpt2"):
-     """Load a GPT-2 model and tokenizer from Hugging Face."""
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     return pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- # Initialize the pipeline outside the function so it's loaded only once
- generator = load_model()
-
- def generate_text(prompt, max_length=100, temperature=1.0, top_p=0.9):
      """
-     Generates text based on the prompt using a GPT-2 model.
-     Args:
-         prompt (str): Input text from the user.
-         max_length (int): Max tokens in the prompt + generation.
-         temperature (float): Controls randomness.
-         top_p (float): Nucleus sampling hyperparameter.
-     Returns:
-         str: Generated text from GPT-2.
      """
-     results = generator(
-         prompt,
-         max_length=max_length,
-         temperature=temperature,
-         top_p=top_p,
-         num_return_sequences=1,
-         # GPT-2 may not have a dedicated pad token, so eos_token_id used:
-         pad_token_id=generator.tokenizer.eos_token_id
      )
-     return results[0]["generated_text"]

- # Build the Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown(
-         """
-         # Educational GPT-2 Demo
-         This demo demonstrates how a smaller Large Language Model (GPT-2) predicts text.
-         Change the parameters below to see how the model's output is affected:
-         - **Max Length** controls the total number of tokens in the output.
-         - **Temperature** controls randomness (higher means more creative/chaotic).
-         - **Top-p** controls the diversity of tokens (lower means more conservative choices).
-         """
      )

      with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(
-                 lines=4,
-                 label="Prompt",
-                 placeholder="Type a prompt here",
-                 value="Once upon a time,"
-             )
-             max_len = gr.Slider(
-                 minimum=20,
-                 maximum=200,
-                 value=100,
-                 step=1,
-                 label="Max Length"
-             )
-             temp = gr.Slider(
-                 minimum=0.1,
-                 maximum=2.0,
-                 value=1.0,
-                 step=0.1,
-                 label="Temperature"
-             )
-             top_p = gr.Slider(
-                 minimum=0.1,
-                 maximum=1.0,
-                 value=0.9,
-                 step=0.05,
-                 label="Top-p"
-             )
-             generate_button = gr.Button("Generate")

-         with gr.Column():
-             output_box = gr.Textbox(
-                 label="Generated Text",
-                 lines=10
-             )

-     generate_button.click(
-         fn=generate_text,
-         inputs=[prompt, max_len, temp, top_p],
-         outputs=[output_box]
      )

  demo.launch()
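One note on the removed version: `temperature` and `top_p` only change the output when sampling is enabled, and the transformers pipeline decodes greedily unless `do_sample=True` is passed, which the removed code never set. A minimal sketch (not part of the commit, assuming the small `gpt2` checkpoint can be downloaded) of how those knobs are meant to behave:

# Sketch only, not part of the commit: exercising the removed demo's sampling knobs.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
for temperature in (0.2, 1.0, 1.8):
    result = generator(
        "Once upon a time,",
        max_length=40,
        do_sample=True,   # required; otherwise temperature/top_p are ignored
        temperature=temperature,
        top_p=0.9,
        num_return_sequences=1,
        pad_token_id=generator.tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    print(temperature, "->", result[0]["generated_text"])

The new version of app.py, which the rest of the diff adds, swaps text generation for attention visualization: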
 
 
+ import torch
  import gradio as gr
+ import plotly.express as px
+ from transformers import AutoModel, AutoTokenizer

+ ########################################
+ # Load Transformer (DistilBERT) with attention
+ ########################################
+ model_name = "distilbert-base-uncased"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # Note: output_attentions=True to extract attention matrices
+ model = AutoModel.from_pretrained(model_name, output_attentions=True)
+ model.eval()
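`output_attentions=True` is the switch that makes the attention tensors available on the model output; a quick sanity check (a sketch, not part of app.py) of what comes back:

# Sketch only: inspect the shapes that output_attentions=True exposes.
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
mdl = AutoModel.from_pretrained("distilbert-base-uncased", output_attentions=True).eval()
with torch.no_grad():
    out = mdl(**tok("a short example", return_tensors="pt"))
print(len(out.attentions))      # 6 -- one attention tensor per DistilBERT layer
print(out.attentions[0].shape)  # torch.Size([1, 12, seq_len, seq_len]) -- 12 heads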
+ def visualize_attention(text, layer=5):
      """
+     1. Tokenize input text.
+     2. Run DistilBERT forward pass to get attention matrices.
+     3. Pick a layer (0..5) and average across attention heads.
+     4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
+     5. Label axes with tokens (Query vs. Key).
      """
+     with torch.no_grad():
+         inputs = tokenizer.encode_plus(text, return_tensors="pt")
+         outputs = model(**inputs)
+     # outputs.attentions: tuple of shape [num_layers] each => (batch=1, num_heads, seq_len, seq_len)
+     all_attentions = outputs.attentions
+     # DistilBERT has 6 layers => valid indices: 0..5
+     attn_layer = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)
+
+     # Convert to numpy for plotting
+     attn_matrix = attn_layer[0].cpu().numpy()
+
+     # Get tokens (including special tokens like [CLS], [SEP])
+     input_ids = inputs["input_ids"][0]
+     tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+     # Build a Plotly heatmap
+     fig = px.imshow(
+         attn_matrix,
+         x=tokens,
+         y=tokens,
+         labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
+         color_continuous_scale="Blues",
+         title=f"DistilBERT Attention (Layer {layer})"
      )
+     fig.update_xaxes(side="top")

+     # Add tooltip: shows row token, column token, and attention weight
+     fig.update_traces(
+         hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
      )
+     return fig
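Because each attention row is a softmax over the keys, the head-averaged matrix that feeds the heatmap still has rows summing to 1. A short check (a sketch reusing the `tokenizer` and `model` defined above, not part of the commit):

# Sketch only: rows of the head-averaged attention matrix are probability distributions.
import torch

with torch.no_grad():
    enc = tokenizer.encode_plus("The cat sat on the mat.", return_tensors="pt")
    attn = model(**enc).attentions[5]   # last layer: (1, num_heads, seq_len, seq_len)
avg = attn.mean(dim=1)[0]               # average heads, drop the batch dimension
print(torch.allclose(avg.sum(dim=-1), torch.ones(avg.shape[0])))  # expected: True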
+
+ # Short explanation text for the UI
+ description_text = """
+ ## Understanding Transformer Self-Attention
+
+ - **Rows = "Query token"** (the token that is looking at other tokens)
+ - **Columns = "Key token"** (the token being looked at)
+ - Darker (or higher) color = stronger attention.
+
+ **Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
+ Thus, **long-distance dependencies** are easier to capture: any token can directly
+ attend to any other token, regardless of distance in the sentence.
+ """
+
+ ########################################
+ # Gradio Interface
+ ########################################
+ with gr.Blocks() as demo:
+     gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
+     gr.Markdown(description_text)

      with gr.Row():
+         text_input = gr.Textbox(
+             label="Enter a sentence",
+             value="Transformers handle long-range context in parallel."
+         )
+         layer_slider = gr.Slider(
+             minimum=0, maximum=5, step=1, value=5,
+             label="DistilBERT Layer (0=lowest, 5=highest)"
+         )

+     output_plot = gr.Plot(label="Attention Heatmap")

+     visualize_button = gr.Button("Visualize Attention")
+     visualize_button.click(
+         fn=visualize_attention,
+         inputs=[text_input, layer_slider],
+         outputs=output_plot
      )

  demo.launch()
+
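For readers new to the query/key vocabulary used in the new UI text, the matrix the heatmap displays is produced by scaled dot-product attention. A self-contained reference sketch (illustrative, not code from this commit):

# Sketch only: the seq_len x seq_len attention matrix behind the heatmap.
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (seq_len, d). Every query scores every key at once -- the
    # "parallel" processing the description contrasts with RNNs.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = torch.softmax(scores, dim=-1)  # row i: how token i attends to all tokens
    return weights @ v, weights

q = k = v = torch.randn(4, 8)
out, weights = scaled_dot_product_attention(q, k, v)
print(weights.shape)  # torch.Size([4, 4]) -- what px.imshow renders, per layer/head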