|
import streamlit as st |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from transformers import AutoTokenizer |
|
from models.huggingface_model import SentimentClassifierForHuggingFace |
|
import numpy as np |
|
import io |
|
from PIL import Image |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the sentiment classifier and its tokenizer from the current directory.

    Both objects are read from ``"./"`` via their ``from_pretrained``
    constructors. The function is wrapped in ``st.cache_resource`` so the
    loaded pair is reused instead of being rebuilt on every call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint_dir = "./"
    classifier = SentimentClassifierForHuggingFace.from_pretrained(checkpoint_dir)
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    return classifier, text_tokenizer
|
|
|
def _render_attention_heatmap(attention_row, token_labels, title):
    """Draw a single-row attention heatmap and return it as a PIL image.

    Args:
        attention_row: 2-D array with one row, holding one weight per entry
            of ``token_labels`` (the caller reshapes to ``(1, -1)``).
        token_labels: Token strings used as the x-axis tick labels.
        title: Text shown above the heatmap.

    Returns:
        The image opened by ``PIL.Image.open`` on the in-memory PNG rendering
        of the figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        sns.heatmap(
            attention_row,
            cmap="YlOrRd",
            annot=True,
            fmt=".2f",
            cbar=False,
            xticklabels=token_labels,
            yticklabels=["Attention"],
            ax=ax,
        )

        # Style through this figure's own Axes/Figure objects instead of
        # pyplot's implicit "current figure", so the title, tick rotation and
        # layout always land on the figure created above.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        ax.set_title(title)
        fig.tight_layout()

        # Render to an in-memory PNG; no temporary file is written.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, including when drawing or saving fails;
        # otherwise pyplot keeps it registered and it is never freed.
        plt.close(fig)


def predict_sentiment(text, model, tokenizer):
    """Classify ``text`` and visualize the model's attention over its tokens.

    Args:
        text: Input text to classify.
        model: Model invoked as
            ``model(input_ids, return_attention=True, return_dict=True)``;
            its output must provide ``"logits"`` and ``"attention_weights"``.
        tokenizer: Tokenizer that is callable on ``text`` and provides
            ``convert_ids_to_tokens``.

    Returns:
        A ``(sentiment, confidence, img)`` tuple: ``sentiment`` is
        ``"Positive"`` when the arg-max class index is 1 and ``"Negative"``
        otherwise, ``confidence`` is the softmax probability of that class as
        a float, and ``img`` is the rendered attention heatmap image.
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = tokens["input_ids"]

    # Inference only: switch to eval mode and skip gradient tracking.
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, return_attention=True, return_dict=True)

    logits = outputs["logits"]
    attention_weights = outputs["attention_weights"]

    probs = torch.nn.functional.softmax(logits, dim=1)
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0][prediction].item()
    sentiment = "Positive" if prediction == 1 else "Negative"

    # Convert the whole id sequence in one call; this also avoids a loop
    # variable that shadows the builtin ``id``.
    tokens_list = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    img = _render_attention_heatmap(
        attention_weights.squeeze(0).cpu().numpy().reshape(1, -1),
        tokens_list,
        f"Prediction: {sentiment} (Confidence: {confidence:.4f})",
    )

    return sentiment, confidence, img
|
|
|
|
|
def main():
    """Render the Streamlit page for sentiment analysis with attention.

    Configures the page, loads the model via ``load_model()``, collects text
    from a text area and, when the "Analyze Sentiment" button is pressed,
    shows the predicted sentiment, its confidence and the attention heatmap
    returned by ``predict_sentiment``. A short interpretation guide is
    rendered at the bottom of the page.
    """
    st.set_page_config(
        page_title="Sentiment Analysis with Attention",
        page_icon="🧠",
        layout="wide",
    )

    st.title("Sentiment Analysis with Attention Visualization")
    st.write("This model classifies text sentiment as positive or negative and visualizes which parts of the text it focused on using an attention mechanism.")

    # A load failure is reported on the page and disables analysis instead of
    # aborting the whole script.
    try:
        model, tokenizer = load_model()
        model_loaded = True
    except Exception as e:
        st.error(f"Error loading model: {e}")
        model_loaded = False

    text_input = st.text_area(
        "Enter text to analyze:",
        value="I absolutely loved this movie! The acting was superb.",
        height=100,
    )

    # st.button() is evaluated first so the button is rendered even when the
    # model is unavailable.
    if st.button("Analyze Sentiment") and model_loaded:
        if not text_input.strip():
            # Nothing meaningful to classify: ask for input rather than
            # presenting a prediction for empty text.
            st.warning("Please enter some text to analyze.")
        else:
            with st.spinner("Analyzing..."):
                sentiment, confidence, viz_img = predict_sentiment(text_input, model, tokenizer)

            col1, col2 = st.columns([1, 3])

            with col1:
                st.subheader("Prediction:")
                sentiment_color = "#5FD068" if sentiment == "Positive" else "#D21312"
                # `sentiment` is one of two fixed labels produced by
                # predict_sentiment, so no user text reaches this HTML.
                st.markdown(
                    f"<div style='background-color:{sentiment_color}; padding:10px; border-radius:5px;"
                    f"color:white; text-align:center; font-size:24px;'>{sentiment}</div>",
                    unsafe_allow_html=True,
                )
                st.metric("Confidence", f"{confidence:.2%}")

            with col2:
                st.subheader("Attention Visualization:")
                st.image(viz_img, use_column_width=True)
                st.caption("The heatmap shows which words the model focused on most when making its prediction.")

    st.markdown("---")
    st.subheader("How to interpret the visualization:")
    st.write(
        "The attention heatmap shows the weight assigned to each token in the text. "
        "Darker colors indicate where the model focused more attention when making its prediction. "
        "This can help identify which parts of the text were most influential for sentiment classification."
    )
|
|
|
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()