EdBianchi committed on
Commit
1e35431
1 Parent(s): 389bf61

Create app.py

Files changed (1)
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
+ import streamlit as st
+ import torch
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from transformers import BertTokenizer, BertModel
+
+ # Load the pre-trained BERT model and tokenizer from Hugging Face
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
+
+ # App title and description
+ st.title("BERT Attention Map Visualizer")
+ st.write(r"""
+ ## Introduction
+ This application visualizes the attention mechanism of the BERT model for a given input sentence.
+ The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token,
+ providing insights into how the model understands the context and relationships between words.
+ This app showcases how BERT generates attention maps and word embeddings using a pre-trained BERT model.
+
+ ### Attention Mechanism
+ The attention mechanism enhances the model's ability to focus on the important parts of the input sequence.
+ It computes a weighted sum of values (V) based on the similarity between queries (Q) and keys (K). The formulation is as follows:
+
+ $$
+ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
+ $$
+
+ where:
+ - $Q$ (Query): Represents the current token for which attention is being calculated.
+ - $K$ (Key): Represents the tokens in the input sequence to compare against the query.
+ - $V$ (Value): Represents the actual values used to compute the attention-weighted sum.
+ - $d_k$: The dimension of the key vectors, used for scaling.
+
+ ### Key, Query, and Value
+ - **Query (Q)**: Captures the essence of the word/token we are focusing on.
+ - **Key (K)**: Represents all the words/tokens the query is compared against.
+ - **Value (V)**: Contains the information of all tokens, aggregated according to the attention scores.
+
+ This mechanism allows the model to dynamically adjust its focus on different parts of the sentence, thereby improving contextual understanding.
+ """)
+
+ # Input sentence from the user
+ sentence = st.text_input("Enter a sentence:", "The cat is on the mat")
+
+ # Tokenize and encode the sentence (adds the [CLS] and [SEP] special tokens)
+ inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True)
+
+ # Get the embeddings and attention weights from BERT
+ outputs = model(**inputs)
+ attention = outputs.attentions  # Attention weights from every layer of the pre-trained model
+ attention_weights = attention[-1].squeeze(0)  # Last layer: (num_heads, seq_len, seq_len)
+
+ # Function to visualize attention weights as a heatmap
+ def visualize_attention(tokens, attention_weights):
+     attention_weights = attention_weights.detach().numpy()
+
+     fig, ax = plt.subplots(figsize=(8, 8))
+     cax = ax.matshow(attention_weights, cmap='viridis')
+
+     plt.xticks(range(len(tokens)), tokens, rotation=90)
+     plt.yticks(range(len(tokens)), tokens)
+
+     fig.colorbar(cax)
+     plt.title("Attention Map")
+     st.pyplot(fig)
+
+ # Extract tokens, including special tokens
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+
+ # Remove special tokens for visualization
+ tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens]
+
+ # Visualize the attention weights of the first head, excluding special tokens
+ visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])
+
+ st.write("""
+ ### About BERT
+ BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings. This attention mechanism is crucial for tasks such as language translation, sentiment analysis, and more.
+ """)
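
As a companion to the formula quoted in the app's intro text, here is a minimal sketch of scaled dot-product attention in PyTorch. It is not part of the committed file; the function name `scaled_dot_product_attention` and the toy tensor sizes are illustrative choices only.

```python
# Illustrative sketch (not part of app.py): the attention formula from the intro text.
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for a single attention head."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # query-key similarity, scaled by sqrt(d_k)
    weights = F.softmax(scores, dim=-1)            # each row is a distribution over the tokens
    return weights @ V, weights                    # context vectors and the attention map itself

# Toy example: 4 tokens with 8-dimensional projections (hypothetical sizes).
Q = torch.randn(4, 8)
K = torch.randn(4, 8)
V = torch.randn(4, 8)
context, attn_map = scaled_dot_product_attention(Q, K, V)
print(context.shape, attn_map.shape)  # torch.Size([4, 8]) torch.Size([4, 4])
```

The `attn_map` returned here plays the same role as a single `(seq_len, seq_len)` slice of the tensors the app reads out of `outputs.attentions`.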
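The slicing in `visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])` is easier to follow once the shape of `outputs.attentions` is spelled out. The sketch below (again, not part of the commit) assumes the standard 12-layer, 12-head `bert-base-uncased` configuration and the app's default sentence:

```python
# Illustrative sketch: the shape of the attention tensors the app slices.
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

inputs = tokenizer("The cat is on the mat", return_tensors='pt', add_special_tokens=True)
with torch.no_grad():
    outputs = model(**inputs)

# One tensor per layer; bert-base-uncased has 12 layers with 12 heads each.
print(len(outputs.attentions))        # 12
print(outputs.attentions[-1].shape)   # torch.Size([1, 12, 8, 8]) -> (batch, heads, seq, seq)

# What the app does: take the last layer, keep head 0, drop the [CLS]/[SEP] rows and columns.
last_layer = outputs.attentions[-1].squeeze(0)   # (num_heads, seq_len, seq_len)
head0_words_only = last_layer[0, 1:-1, 1:-1]     # attention among the actual words
print(head0_words_only.shape)                    # torch.Size([6, 6]) for the default sentence

# Another common view (not used in the app): average over all heads of the last layer.
mean_map = last_layer.mean(dim=0)[1:-1, 1:-1]
```

Because only head 0 is plotted, different heads (or the mean over heads) can produce noticeably different maps for the same sentence.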
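One possible refinement, not part of this commit: Streamlit re-executes `app.py` on every interaction, so the model and tokenizer are reloaded each time the input sentence changes. Assuming a reasonably recent Streamlit version with `st.cache_resource`, the loading step could be cached:

```python
# Possible refinement (not in this commit): cache the heavy loading step so
# Streamlit does not re-instantiate BERT on every rerun of the script.
import streamlit as st
from transformers import BertTokenizer, BertModel

@st.cache_resource
def load_bert():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
    model.eval()
    return tokenizer, model

tokenizer, model = load_bert()
```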