mprateek commited on
Commit
a9ad9c4
1 Parent(s): 24700e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from transformers import pipeline
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+
7
+ # Set Streamlit configuration to disable deprecation warnings
8
+ st.set_option('deprecation.showPyplotGlobalUse', False)
9
+
10
+ # Initialize the image classification pipeline with the specified model
11
+ pipe = pipeline("image-classification", model="trpakov/vit-face-expression")
12
+
13
+ # Set the title of the Streamlit app
14
+ st.title("Emotion Recognition with vit-face-expression")
15
+
16
+ # Create a file uploader to upload images in JPG or PNG format
17
+ uploaded_images = st.file_uploader("Upload images", type=["jpg", "png"], accept_multiple_files=True)
18
+
19
+ # List to store selected file names
20
+ selected_file_names = []
21
+
22
+ # List to store selected images
23
+ selected_images = []
24
+
25
+ # Process uploaded images if any
26
+ if uploaded_images:
27
+ # Add a "Select All" checkbox in the sidebar for convenience
28
+ select_all = st.sidebar.checkbox("Select All", False)
29
+
30
+ # Iterate over each uploaded image
31
+ for idx, img in enumerate(uploaded_images):
32
+ image = Image.open(img)
33
+ checkbox_key = f"{img.name}_checkbox_{idx}" # Unique key for each checkbox
34
+
35
+ # Display thumbnail image and checkbox in sidebar
36
+ st.sidebar.image(image, caption=f"{img.name} ({img.size / 1024.0:.1f} KB)", width=40)
37
+ selected = st.sidebar.checkbox(f"Select {img.name}", value=select_all, key=checkbox_key)
38
+
39
+ # Add selected images to the list
40
+ if selected:
41
+ selected_images.append(image)
42
+ selected_file_names.append(img.name)
43
+
44
+ # Button to start emotion prediction
45
+ if st.button("Predict Emotions") and selected_images:
46
+ # Predict emotion for each selected image using the pipeline
47
+ results = [pipe(image) for image in selected_images]
48
+ emotions = [result[0]["label"].split("_")[-1].capitalize() for result in results]
49
+
50
+ # Display images and predicted emotions
51
+ for i, (image, result) in enumerate(zip(selected_images, results)):
52
+ st.image(image, caption=f"Predicted emotion: {emotions[i]}", use_column_width=True)
53
+ st.write(f"Emotion Scores for Image #{i+1}")
54
+ st.write(f"{emotions[i]}: {result[0]['score']:.4f}")
55
+ st.write(f"Original File Name: {selected_file_names[i]}")
56
+
57
+ # Calculate emotion statistics
58
+ emotion_counts = pd.Series(emotions).value_counts()
59
+ total_faces = len(selected_images)
60
+
61
+ # Define a color map for emotions
62
+ color_map = {
63
+ 'Neutral': '#B38B6D',
64
+ 'Happy': '#FFFF00',
65
+ 'Sad': '#0000FF',
66
+ 'Angry': '#FF0000',
67
+ 'Disgust': '#008000',
68
+ 'Surprise': '#FFA500',
69
+ 'Fear': '#000000'
70
+ }
71
+
72
+ # Plot pie chart for emotion distribution
73
+ st.write("Emotion Distribution (Pie Chart):")
74
+ fig_pie, ax_pie = plt.subplots()
75
+ pie_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index]
76
+ ax_pie.pie(emotion_counts, labels=emotion_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
77
+ ax_pie.axis('equal')
78
+ ax_pie.set_title(f"Total Faces Analyzed: {total_faces}")
79
+ st.pyplot(fig_pie)
80
+
81
+ # Plot bar chart for emotion distribution
82
+ st.write("Emotion Distribution (Bar Chart):")
83
+ fig_bar, ax_bar = plt.subplots()
84
+ bar_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index]
85
+ emotion_counts.plot(kind='bar', color=bar_colors, ax=ax_bar)
86
+ ax_bar.set_xlabel('Emotion')
87
+ ax_bar.set_ylabel('Count')
88
+ ax_bar.set_title(f"Emotion Distribution - Total Faces Analyzed: {total_faces}")
89
+ ax_bar.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
90
+ for i in ax_bar.patches:
91
+ ax_bar.text(i.get_x() + i.get_width() / 2, i.get_height() + 0.1, int(i.get_height()), ha='center', va='bottom')
92
+ st.pyplot(fig_bar)