import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
from transformers import BertTokenizer, BertModel
# Load pre-trained BERT model and tokenizer from HuggingFace
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
# App title and description
st.title("BERT Attention Map Visualizer")
st.write("""
## Introduction
This application visualizes the attention mechanism of the BERT model for a given input sentence.
The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token,
providing insights into how the model understands the context and relationships between words.
This app showcases the attention maps a pre-trained BERT model produces for a sentence you type in.
### Attention Mechanism
The attention mechanism is a method to enhance the ability of the model to focus on important parts of the input sequence.
It computes a weighted sum of values (V) based on the similarity between queries (Q) and keys (K). The formulation is as follows:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$
where:
- $Q$ (Query): Represents the current token for which attention is being calculated.
- $K$ (Key): Represents the tokens in the input sequence to compare against the query.
- $V$ (Value): Represents the actual values used to compute the attention-weighted sum.
- $d_k$: Dimension of the key vectors, used for scaling.
### Key, Query, and Value
- **Query (Q)**: Captures the essence of the word/token we are focusing on.
- **Key (K)**: Represents all words/tokens we are comparing the query against.
- **Value (V)**: Contains the information from each token, which is aggregated according to the attention scores.
This mechanism allows the model to dynamically adjust its focus on different parts of the sentence, thereby improving contextual understanding.
""")
# Input sentence from the user
sentence = st.text_input("Enter a sentence:", "The cat is on the mat")
# Tokenize and encode the sentence
inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True)
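# `inputs` is a dict of tensors of shape (1, seq_len): 'input_ids' (token ids,
# with [CLS] prepended and [SEP] appended), 'token_type_ids' and 'attention_mask'.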
# Get the embeddings and attention weights from BERT (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
attention = outputs.attentions # Extract attention weights directly from the pretrained model
attention_weights = attention[-1].squeeze(0) # Get attention from the last layer
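# Note: `attention` is a tuple with one tensor per layer (12 for bert-base-uncased),
# each of shape (batch_size, num_heads, seq_len, seq_len). After squeeze(0),
# `attention_weights` is (num_heads, seq_len, seq_len) for the last layer.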
# Function to visualize attention weights
def visualize_attention(tokens, attention_weights):
    attention_weights = attention_weights.detach().numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(attention_weights, cmap='viridis')
    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.yticks(range(len(tokens)), tokens)
    fig.colorbar(cax)
    plt.title("Attention Map")
    st.pyplot(fig)
# Extract tokens including special tokens
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# Remove special tokens for visualization
tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens]
# Visualize the first attention head (index 0) of the last layer, dropping the
# [CLS] and [SEP] positions so the map lines up with tokens_vis
visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])
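# Illustrative sketch (defined but not called by the app): the same kind of plot
# could be drawn for any layer/head pair, since `attention` holds one
# (num_heads, seq_len, seq_len) map per layer.
def visualize_layer_head(layer_idx, head_idx):
    layer_attention = attention[layer_idx].squeeze(0)  # (num_heads, seq_len, seq_len)
    visualize_attention(tokens_vis, layer_attention[head_idx, 1:-1, 1:-1])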
st.write("""
### About BERT
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings. This attention mechanism is crucial for tasks like language translation, sentiment analysis, and more.
""")