Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from transformers import BertTokenizer, BertModel | |
# Load pre-trained BERT model and tokenizer from HuggingFace | |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | |
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True) | |
# App title and description | |
st.title("BERT Attention Map Visualizer") | |
st.write(""" | |
## Introduction | |
This application visualizes the attention mechanism of the BERT model for a given input sentence. | |
The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token, | |
providing insights into how the model understands the context and relationships between words. | |
This app showcases how BERT generates attention maps and word embeddings using a pre-trained BERT model. | |
### Attention Mechanism | |
The attention mechanism is a method to enhance the ability of the model to focus on important parts of the input sequence. | |
It computes a weighted sum of values (V) based on the similarity between queries (Q) and keys (K). The formulation is as follows: | |
$$ | |
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V | |
$$ | |
where: | |
- \( Q \) (Query): Represents the current token for which attention is being calculated. | |
- \( K \) (Key): Represents the tokens in the input sequence to compare against the query. | |
- \( V \) (Value): Represents the actual values used to compute the attention-weighted sum. | |
- \( d_k \): Dimension of the key vectors, used for scaling. | |
### Key, Query, and Value | |
- **Query (Q)**: Captures the essence of the word/token we are focusing on. | |
- **Key (K)**: Represents all words/tokens we are comparing the query against. | |
- **Value (V)**: Contains the information of all tokens that is aggregated based on attention scores. | |
This mechanism allows the model to dynamically adjust its focus on different parts of the sentence, thereby improving contextual understanding. | |
""") | |
# Input sentence from the user | |
sentence = st.text_input("Enter a sentence:", "The cat is on the mat") | |
# Tokenize and encode the sentence | |
inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True) | |
# Get the embeddings and attention weights from BERT | |
outputs = model(**inputs) | |
attention = outputs.attentions # Extract attention weights directly from the pretrained model | |
attention_weights = attention[-1].squeeze(0) # Get attention from the last layer | |
# Function to visualize attention weights | |
def visualize_attention(tokens, attention_weights): | |
attention_weights = attention_weights.detach().numpy() | |
fig, ax = plt.subplots(figsize=(8, 8)) | |
cax = ax.matshow(attention_weights, cmap='viridis') | |
plt.xticks(range(len(tokens)), tokens, rotation=90) | |
plt.yticks(range(len(tokens)), tokens) | |
fig.colorbar(cax) | |
plt.title("Attention Map") | |
st.pyplot(fig) | |
# Extract tokens including special tokens | |
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) | |
# Remove special tokens for visualization | |
tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens] | |
# Visualize the attention weights for the sentence excluding special tokens | |
visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1]) | |
st.write(""" | |
### About BERT | |
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings. This attention mechanism is crucial for tasks like language translation, sentiment analysis, and more. | |
""") | |