Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
from transformers import BertTokenizer, BertModel
|
6 |
+
|
7 |
+
# Load pre-trained BERT model and tokenizer from HuggingFace
|
8 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
9 |
+
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
|
10 |
+
|
11 |
+
# App title and description
|
12 |
+
st.title("BERT Attention Map Visualizer")
|
13 |
+
st.write("""
|
14 |
+
## Introduction
|
15 |
+
This application visualizes the attention mechanism of the BERT model for a given input sentence.
|
16 |
+
The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token,
|
17 |
+
providing insights into how the model understands the context and relationships between words.
|
18 |
+
This app showcases how BERT generates attention maps and word embeddings using a pre-trained BERT model.
|
19 |
+
|
20 |
+
### Attention Mechanism
|
21 |
+
The attention mechanism is a method to enhance the ability of the model to focus on important parts of the input sequence.
|
22 |
+
It computes a weighted sum of values (V) based on the similarity between queries (Q) and keys (K). The formulation is as follows:
|
23 |
+
|
24 |
+
$$
|
25 |
+
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
|
26 |
+
$$
|
27 |
+
|
28 |
+
where:
|
29 |
+
- \( Q \) (Query): Represents the current token for which attention is being calculated.
|
30 |
+
- \( K \) (Key): Represents the tokens in the input sequence to compare against the query.
|
31 |
+
- \( V \) (Value): Represents the actual values used to compute the attention-weighted sum.
|
32 |
+
- \( d_k \): Dimension of the key vectors, used for scaling.
|
33 |
+
|
34 |
+
### Key, Query, and Value
|
35 |
+
- **Query (Q)**: Captures the essence of the word/token we are focusing on.
|
36 |
+
- **Key (K)**: Represents all words/tokens we are comparing the query against.
|
37 |
+
- **Value (V)**: Contains the information of all tokens that is aggregated based on attention scores.
|
38 |
+
|
39 |
+
This mechanism allows the model to dynamically adjust its focus on different parts of the sentence, thereby improving contextual understanding.
|
40 |
+
""")
|
41 |
+
|
42 |
+
# Input sentence from the user
|
43 |
+
sentence = st.text_input("Enter a sentence:", "The cat is on the mat")
|
44 |
+
|
45 |
+
# Tokenize and encode the sentence
|
46 |
+
inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True)
|
47 |
+
|
48 |
+
# Get the embeddings and attention weights from BERT
|
49 |
+
outputs = model(**inputs)
|
50 |
+
attention = outputs.attentions # Extract attention weights directly from the pretrained model
|
51 |
+
attention_weights = attention[-1].squeeze(0) # Get attention from the last layer
|
52 |
+
|
53 |
+
# Function to visualize attention weights
|
54 |
+
def visualize_attention(tokens, attention_weights):
|
55 |
+
attention_weights = attention_weights.detach().numpy()
|
56 |
+
|
57 |
+
fig, ax = plt.subplots(figsize=(8, 8))
|
58 |
+
cax = ax.matshow(attention_weights, cmap='viridis')
|
59 |
+
|
60 |
+
plt.xticks(range(len(tokens)), tokens, rotation=90)
|
61 |
+
plt.yticks(range(len(tokens)), tokens)
|
62 |
+
|
63 |
+
fig.colorbar(cax)
|
64 |
+
plt.title("Attention Map")
|
65 |
+
st.pyplot(fig)
|
66 |
+
|
67 |
+
# Extract tokens including special tokens
|
68 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
69 |
+
|
70 |
+
# Remove special tokens for visualization
|
71 |
+
tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens]
|
72 |
+
|
73 |
+
# Visualize the attention weights for the sentence excluding special tokens
|
74 |
+
visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])
|
75 |
+
|
76 |
+
st.write("""
|
77 |
+
### About BERT
|
78 |
+
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings. This attention mechanism is crucial for tasks like language translation, sentiment analysis, and more.
|
79 |
+
""")
|