import streamlit as st
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer
from models.huggingface_model import SentimentClassifierForHuggingFace
import numpy as np
import io
from PIL import Image


# Load the fine-tuned model and tokenizer from the current directory,
# caching them so they are only loaded once per Streamlit session.
@st.cache_resource
def load_model():
    model = SentimentClassifierForHuggingFace.from_pretrained("./")
    tokenizer = AutoTokenizer.from_pretrained("./")
    return model, tokenizer


def predict_sentiment(text, model, tokenizer):
    # Tokenize the input
    tokens = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=128
    )
    input_ids = tokens["input_ids"]

    # Run inference
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, return_attention=True, return_dict=True)

    # Get prediction results
    logits = outputs["logits"]
    attention_weights = outputs["attention_weights"]

    # Convert logits to probabilities and take the most likely class
    probs = torch.nn.functional.softmax(logits, dim=1)
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0][prediction].item()
    sentiment = "Positive" if prediction == 1 else "Negative"

    # Get token strings for visualization
    tokens_list = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Plot a 1 x seq_len heatmap of the attention weights over the tokens
    fig, ax = plt.subplots(figsize=(10, 2))
    sns.heatmap(
        attention_weights.squeeze(0).cpu().numpy().reshape(1, -1),
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
        cbar=False,
        xticklabels=tokens_list,
        yticklabels=["Attention"],
        ax=ax,
    )

    # Rotate x-axis labels for better readability
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
    plt.title(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
    plt.tight_layout()

    # Render the plot to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)

    return sentiment, confidence, img


# Streamlit app
def main():
    st.set_page_config(
        page_title="Sentiment Analysis with Attention",
        page_icon="🧠",
        layout="wide",
    )

    st.title("Sentiment Analysis with Attention Visualization")
    st.write(
        "This model classifies text sentiment as positive or negative and "
        "visualizes which parts of the text it focused on using an attention mechanism."
    )

    # Load model and tokenizer
    try:
        model, tokenizer = load_model()
        model_loaded = True
    except Exception as e:
        st.error(f"Error loading model: {e}")
        model_loaded = False

    # Text input
    text_input = st.text_area(
        "Enter text to analyze:",
        value="I absolutely loved this movie! The acting was superb.",
        height=100,
    )

    # Process when the button is clicked
    if st.button("Analyze Sentiment") and model_loaded:
        with st.spinner("Analyzing..."):
            sentiment, confidence, viz_img = predict_sentiment(text_input, model, tokenizer)

        # Display results
        col1, col2 = st.columns([1, 3])

        with col1:
            st.subheader("Prediction:")
            sentiment_color = "#5FD068" if sentiment == "Positive" else "#D21312"
            # Colored badge for the predicted sentiment (rendered as inline HTML)
            st.markdown(
                f"<div style='background-color: {sentiment_color}; padding: 10px; "
                f"border-radius: 5px; text-align: center; color: white; "
                f"font-size: 24px; font-weight: bold;'>{sentiment}</div>",
                unsafe_allow_html=True,
            )
            st.metric("Confidence", f"{confidence:.2%}")

        with col2:
            st.subheader("Attention Visualization:")
            st.image(viz_img, use_column_width=True)
            st.caption(
                "The heatmap shows which words the model focused on most "
                "when making its prediction."
            )

    st.markdown("---")
    st.subheader("How to interpret the visualization:")
    st.write(
        "The attention heatmap shows the weight assigned to each token in the text. "
        "Darker colors indicate where the model focused more attention when making its prediction. "
        "This can help identify which parts of the text were most influential for sentiment classification."
    )


if __name__ == "__main__":
    main()