import streamlit as st from PIL import Image from transformers import pipeline import pandas as pd import matplotlib.pyplot as plt # Set Streamlit configuration to disable PyplotGlobalUseWarning st.set_option('deprecation.showPyplotGlobalUse', False) # Initialize an image classification pipeline pipe = pipeline("image-classification", model="trpakov/vit-face-expression") # Title of the Streamlit app st.title("Single Image Emotion Recognition") # Upload a single image uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png"], accept_multiple_files=False) # Process the image immediately after it's uploaded if uploaded_image: # Load the image image = Image.open(uploaded_image) # Display the uploaded image st.image(image, caption="Uploaded Image", use_column_width=True) # Predict emotion using the pipeline result = pipe(image) predicted_class = result[0]["label"] predicted_emotion = predicted_class.split("_")[-1].capitalize() emotion_score = result[0]["score"] # Display predicted emotion and score st.write(f"Predicted Emotion: {predicted_emotion}") st.write(f"Emotion Score: {emotion_score:.4f}") # Emotion statistics (for visualization) emotion_counts = pd.Series([predicted_emotion]) # Define a color map for emotions color_map = { 'Neutral': '#B38B6D', 'Happy': '#FFFF00', 'Sad': '#0000FF', 'Angry': '#FF0000', 'Disgust': '#008000', 'Surprise': '#FFA500', 'Fear': '#000000' } # Assign colors to the pie chart pie_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index] # Plot a small pie chart fig_pie, ax_pie = plt.subplots(figsize=(4, 4)) ax_pie.pie(emotion_counts, labels=emotion_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors) ax_pie.axis('equal') ax_pie.set_title("Emotion Distribution") st.pyplot(fig_pie) # Plot a small bar chart fig_bar, ax_bar = plt.subplots(figsize=(4, 4)) emotion_counts.plot(kind='bar', color=pie_colors, ax=ax_bar) ax_bar.set_xlabel('Emotion') ax_bar.set_ylabel('Count') ax_bar.set_title("Emotion Count") ax_bar.yaxis.set_major_locator(plt.MaxNLocator(integer=True)) for i in ax_bar.patches: ax_bar.text(i.get_x() + i.get_width() / 2, i.get_height() + 0.1, int(i.get_height()), ha='center', va='bottom') st.pyplot(fig_bar)