import streamlit as st
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer
from models.huggingface_model import SentimentClassifierForHuggingFace
import numpy as np
import io
from PIL import Image
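
# Note: the model weights and tokenizer files are loaded from the repository root ("./"),
# i.e. the same directory as this script; the model class itself is defined in
# models/huggingface_model.py.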
# Load model and tokenizer
@st.cache_resource
def load_model():
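    """Load the fine-tuned classifier and its tokenizer from the repo root (cached across reruns)."""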
    model = SentimentClassifierForHuggingFace.from_pretrained("./")
    tokenizer = AutoTokenizer.from_pretrained("./")
    return model, tokenizer

def predict_sentiment(text, model, tokenizer):
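    """Classify the sentiment of `text` and build a token-level attention heatmap.

    Returns the predicted sentiment label, the softmax confidence, and a PIL
    image of the attention visualization.
    """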
    # Tokenize the input
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = tokens["input_ids"]

    # Run inference
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, return_attention=True, return_dict=True)

    # Get prediction results
    logits = outputs["logits"]
    attention_weights = outputs["attention_weights"]

    # Convert to probabilities and get prediction
    probs = torch.nn.functional.softmax(logits, dim=1)
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0][prediction].item()
    sentiment = "Positive" if prediction == 1 else "Negative"

    # Get token strings for visualization
    tokens_list = []
    for token_id in input_ids[0]:
        token = tokenizer.convert_ids_to_tokens(token_id.item())
        tokens_list.append(token)

    # Create visualization
    fig, ax = plt.subplots(figsize=(10, 2))
    sns.heatmap(
        attention_weights.squeeze(0).cpu().numpy().reshape(1, -1),
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
        cbar=False,
        xticklabels=tokens_list,
        yticklabels=["Attention"],
        ax=ax
    )

    # Rotate x-axis labels for better readability
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
    plt.title(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
    plt.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)

    return sentiment, confidence, img

# Streamlit app
def main():
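    """Streamlit UI: text input, sentiment prediction, and attention visualization."""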
    st.set_page_config(
        page_title="Sentiment Analysis with Attention",
        page_icon="🧠",
        layout="wide"
    )
    st.title("Sentiment Analysis with Attention Visualization")
    st.write("This model classifies text sentiment as positive or negative and visualizes which parts of the text it focused on using an attention mechanism.")

    # Load model and tokenizer
    try:
        model, tokenizer = load_model()
        model_loaded = True
    except Exception as e:
        st.error(f"Error loading model: {e}")
        model_loaded = False

    # Text input
    text_input = st.text_area(
        "Enter text to analyze:",
        value="I absolutely loved this movie! The acting was superb.",
        height=100,
    )

    # Process when button is clicked
    if st.button("Analyze Sentiment") and model_loaded:
        with st.spinner("Analyzing..."):
            sentiment, confidence, viz_img = predict_sentiment(text_input, model, tokenizer)

        # Display results
        col1, col2 = st.columns([1, 3])
        with col1:
            st.subheader("Prediction:")
            sentiment_color = "#5FD068" if sentiment == "Positive" else "#D21312"
            st.markdown(
                f"<div style='background-color:{sentiment_color}; padding:10px; border-radius:5px;"
                f"color:white; text-align:center; font-size:24px;'>{sentiment}</div>",
                unsafe_allow_html=True
            )
            st.metric("Confidence", f"{confidence:.2%}")
        with col2:
            st.subheader("Attention Visualization:")
            st.image(viz_img, use_column_width=True)
            st.caption("The heatmap shows which words the model focused on most when making its prediction.")

    st.markdown("---")
    st.subheader("How to interpret the visualization:")
    st.write(
        "The attention heatmap shows the weight assigned to each token in the text. "
        "Darker colors indicate where the model focused more attention when making its prediction. "
        "This can help identify which parts of the text were most influential for sentiment classification."
    )

if __name__ == "__main__":
    main()