# Spaces status: Running
import streamlit as st
from streamlit.errors import StreamlitAPIException
from PIL import Image
from transformers import pipeline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pandas.plotting import parallel_coordinates
# Per-session storage that survives Streamlit reruns: the prediction results
# plus the image names and sizes they refer to. Each key is created (as a fresh
# empty list) only the first time this session runs the script.
for _state_key in ('results', 'image_names', 'image_sizes'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
# Disable PyplotGlobalUseWarning on Streamlit versions that still have this
# option. Newer releases removed it, and st.set_option raises
# StreamlitAPIException for unrecognized keys, so tolerate that instead of
# aborting the whole app at startup.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass  # option not available in this Streamlit version; nothing to disable
# Create the image-classification pipeline (top_k=None requests a score for
# every label, not just the top few). Streamlit re-executes this script on
# every widget interaction, so the pipeline is built once and cached with
# st.cache_resource instead of reloading the model on each click.
@st.cache_resource
def _load_emotion_pipeline():
    """Return the shared vit-face-expression classification pipeline."""
    return pipeline("image-classification", model="trpakov/vit-face-expression", top_k=None)


pipe = _load_emotion_pipeline()
# Page heading, followed by the multi-file image upload widget
# (JPG/PNG only; returns a list of uploaded files, or an empty value).
st.title("Emotion Recognition with vit-face-expression")
uploaded_images = st.file_uploader(
    "Upload images",
    type=["jpg", "png"],
    accept_multiple_files=True,
)
# Sidebar: a thumbnail with name/size caption and a checkbox for each uploaded
# file. Only checked images are classified. Their names and sizes are collected
# in the same order as the images, so they stay aligned with the results.
selected_images = []
selected_names = []
selected_sizes = []
if uploaded_images:
    # "Select All" supplies the default value of every per-image checkbox below
    select_all = st.sidebar.checkbox("Select All", False)
    for idx, img in enumerate(uploaded_images):
        image = Image.open(img)
        size_kb = img.size / 1024.0
        # The index in the key keeps widgets distinct when file names repeat
        checkbox_key = f"{img.name}_checkbox_{idx}"
        st.sidebar.image(image, caption=f"{img.name} {size_kb:.1f} KB", width=40)
        selected = st.sidebar.checkbox(f"Select {img.name}", value=select_all, key=checkbox_key)
        if selected:
            selected_images.append(image)
            selected_names.append(img.name)
            selected_sizes.append(round(size_kb, 1))

if st.button("Predict Emotions"):
    if selected_images:
        # Store the results together with the names/sizes of exactly the images
        # that were classified. Previously the names/sizes covered *all* uploads,
        # so results[i] was mislabeled whenever only a subset was selected.
        st.session_state['results'] = [pipe(image) for image in selected_images]
        st.session_state['image_names'] = selected_names
        st.session_state['image_sizes'] = selected_sizes
    else:
        st.warning("Select at least one image before predicting emotions.")
# Generate a DataFrame and heatmap from the stored predictions, and offer the
# scores as a CSV download.
if st.button("Generate HeatMap & DataFrame"):
    # Results and the image details they belong to persist in session state
    results = st.session_state['results']
    image_names = st.session_state['image_names']
    image_sizes = st.session_state['image_sizes']
    if not results:
        st.error("No results to generate DataFrame. Please predict emotions first.")
    elif len(image_names) < len(results) or len(image_sizes) < len(results):
        # Stored details can fall out of step with the results (e.g. uploads
        # changed after predicting); report it instead of raising IndexError.
        st.error("Stored image details do not match the predictions. Please predict emotions again.")
    else:
        # Each row starts with 0 for every known emotion. Scores returned by the
        # model overwrite these; unknown labels become extra columns at the end.
        emotion_columns = ['Happy', 'Surprise', 'Neutral', 'Sad', 'Disgust', 'Angry', 'Fear']
        info_columns = ('Image Name', 'Image Size (KB)')
        data = []
        for result_set, name, size_kb in zip(results, image_names, image_sizes):
            row = dict.fromkeys(emotion_columns, 0)
            row['Image Name'] = name
            row['Image Size (KB)'] = f"{size_kb:.1f}"  # one decimal place, stored as text
            for result in result_set:
                # Capitalize the model label so it matches the column names above
                row[result['label'].capitalize()] = round(result['score'], 4)
            data.append(row)
        df_emotions = pd.DataFrame(data)
        st.write(df_emotions)

        # Heatmap of all score columns (chosen by name, not position), one row
        # per image. An explicit figure is rendered and then closed, so figures
        # do not accumulate in pyplot's global state across reruns.
        score_columns = [col for col in df_emotions.columns if col not in info_columns]
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(
            df_emotions[score_columns],
            annot=True,
            fmt=".1f",
            cmap='viridis',
            yticklabels=df_emotions['Image Name'].tolist(),
            ax=ax,
        )
        ax.set_title('Heatmap of Emotion Scores')
        ax.set_xlabel('Emotion Categories')
        ax.set_ylabel('Image')
        st.pyplot(fig)
        plt.close(fig)

        # Serialize once in memory and serve that string for the download. The
        # on-disk copy is still written, but its fixed path is shared by every
        # session this process serves. Re-reading it could therefore return
        # data another session wrote in the meantime.
        csv_file = df_emotions.to_csv(index=False)
        with open('emotion_scores.csv', 'w', encoding='utf-8', newline='') as f:
            f.write(csv_file)
        st.success('DataFrame generated and saved as emotion_scores.csv')
        st.download_button(
            label='Download Emotion Scores as CSV',
            data=csv_file,
            file_name='emotion_scores.csv',
            mime='text/csv',
        )
        st.success('DataFrame generated and available for download.')