mprateek commited on
Commit
e0510c6
1 Parent(s): 8fe6335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -53
app.py CHANGED
@@ -4,59 +4,38 @@ from transformers import pipeline
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
6
 
7
- # Set Streamlit configuration to disable deprecation warnings
8
  st.set_option('deprecation.showPyplotGlobalUse', False)
9
 
10
- # Initialize the image classification pipeline with the specified model
11
  pipe = pipeline("image-classification", model="trpakov/vit-face-expression")
12
 
13
- # Set the title of the Streamlit app
14
- st.title("Emotion Recognition with vit-face-expression")
15
 
16
- # Create a file uploader to upload images in JPG or PNG format
17
- uploaded_images = st.file_uploader("Upload images", type=["jpg", "png"], accept_multiple_files=True)
18
 
19
- # List to store selected file names
20
- selected_file_names = []
 
 
21
 
22
- # List to store selected images
23
- selected_images = []
24
 
25
- # Process uploaded images if any
26
- if uploaded_images:
27
- # Add a "Select All" checkbox in the sidebar for convenience
28
- select_all = st.sidebar.checkbox("Select All", False)
29
-
30
- # Iterate over each uploaded image
31
- for idx, img in enumerate(uploaded_images):
32
- image = Image.open(img)
33
- checkbox_key = f"{img.name}_checkbox_{idx}" # Unique key for each checkbox
34
-
35
- # Display thumbnail image and checkbox in sidebar
36
- st.sidebar.image(image, caption=f"{img.name} ({img.size / 1024.0:.1f} KB)", width=40)
37
- selected = st.sidebar.checkbox(f"Select {img.name}", value=select_all, key=checkbox_key)
38
-
39
- # Add selected images to the list
40
- if selected:
41
- selected_images.append(image)
42
- selected_file_names.append(img.name)
43
 
44
- # Button to start emotion prediction
45
- if st.button("Predict Emotions") and selected_images:
46
- # Predict emotion for each selected image using the pipeline
47
- results = [pipe(image) for image in selected_images]
48
- emotions = [result[0]["label"].split("_")[-1].capitalize() for result in results]
49
 
50
- # Display images and predicted emotions
51
- for i, (image, result) in enumerate(zip(selected_images, results)):
52
- st.image(image, caption=f"Predicted emotion: {emotions[i]}", use_column_width=True)
53
- st.write(f"Emotion Scores for Image #{i+1}")
54
- st.write(f"{emotions[i]}: {result[0]['score']:.4f}")
55
- st.write(f"Original File Name: {selected_file_names[i]}")
56
-
57
- # Calculate emotion statistics
58
- emotion_counts = pd.Series(emotions).value_counts()
59
- total_faces = len(selected_images)
60
 
61
  # Define a color map for emotions
62
  color_map = {
@@ -69,23 +48,22 @@ if st.button("Predict Emotions") and selected_images:
69
  'Fear': '#000000'
70
  }
71
 
72
- # Plot pie chart for emotion distribution
73
- st.write("Emotion Distribution (Pie Chart):")
74
- fig_pie, ax_pie = plt.subplots()
75
  pie_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index]
 
 
 
76
  ax_pie.pie(emotion_counts, labels=emotion_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
77
  ax_pie.axis('equal')
78
- ax_pie.set_title(f"Total Faces Analyzed: {total_faces}")
79
  st.pyplot(fig_pie)
80
 
81
- # Plot bar chart for emotion distribution
82
- st.write("Emotion Distribution (Bar Chart):")
83
- fig_bar, ax_bar = plt.subplots()
84
- bar_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index]
85
- emotion_counts.plot(kind='bar', color=bar_colors, ax=ax_bar)
86
  ax_bar.set_xlabel('Emotion')
87
  ax_bar.set_ylabel('Count')
88
- ax_bar.set_title(f"Emotion Distribution - Total Faces Analyzed: {total_faces}")
89
  ax_bar.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
90
  for i in ax_bar.patches:
91
  ax_bar.text(i.get_x() + i.get_width() / 2, i.get_height() + 0.1, int(i.get_height()), ha='center', va='bottom')
 
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
6
 
7
+ # Set Streamlit configuration to disable PyplotGlobalUseWarning
8
  st.set_option('deprecation.showPyplotGlobalUse', False)
9
 
10
+ # Initialize an image classification pipeline
11
  pipe = pipeline("image-classification", model="trpakov/vit-face-expression")
12
 
13
+ # Title of the Streamlit app
14
+ st.title("Single Image Emotion Recognition")
15
 
16
+ # Upload a single image
17
+ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png"], accept_multiple_files=False)
18
 
19
+ # Process the image immediately after it's uploaded
20
+ if uploaded_image:
21
+ # Load the image
22
+ image = Image.open(uploaded_image)
23
 
24
+ # Display the uploaded image
25
+ st.image(image, caption="Uploaded Image", use_column_width=True)
26
 
27
+ # Predict emotion using the pipeline
28
+ result = pipe(image)
29
+ predicted_class = result[0]["label"]
30
+ predicted_emotion = predicted_class.split("_")[-1].capitalize()
31
+ emotion_score = result[0]["score"]
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Display predicted emotion and score
34
+ st.write(f"Predicted Emotion: {predicted_emotion}")
35
+ st.write(f"Emotion Score: {emotion_score:.4f}")
 
 
36
 
37
+ # Emotion statistics (for visualization)
38
+ emotion_counts = pd.Series([predicted_emotion])
 
 
 
 
 
 
 
 
39
 
40
  # Define a color map for emotions
41
  color_map = {
 
48
  'Fear': '#000000'
49
  }
50
 
51
+ # Assign colors to the pie chart
 
 
52
  pie_colors = [color_map.get(emotion, '#999999') for emotion in emotion_counts.index]
53
+
54
+ # Plot a small pie chart
55
+ fig_pie, ax_pie = plt.subplots(figsize=(4, 4))
56
  ax_pie.pie(emotion_counts, labels=emotion_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
57
  ax_pie.axis('equal')
58
+ ax_pie.set_title("Emotion Distribution")
59
  st.pyplot(fig_pie)
60
 
61
+ # Plot a small bar chart
62
+ fig_bar, ax_bar = plt.subplots(figsize=(4, 4))
63
+ emotion_counts.plot(kind='bar', color=pie_colors, ax=ax_bar)
 
 
64
  ax_bar.set_xlabel('Emotion')
65
  ax_bar.set_ylabel('Count')
66
+ ax_bar.set_title("Emotion Count")
67
  ax_bar.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
68
  for i in ax_bar.patches:
69
  ax_bar.text(i.get_x() + i.get_width() / 2, i.get_height() + 0.1, int(i.get_height()), ha='center', va='bottom')