Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- README.md +0 -13
- app.py +149 -0
- optimized_emotion_classifier.pkl +3 -0
- optimized_emotion_evaluation.py +413 -0
- optimized_emotion_model.py +745 -0
- requirements.txt +9 -0
README.md
CHANGED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Emotion Recognition App
|
| 3 |
-
emoji: 🐠
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.29.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import os
|
| 6 |
+
try:
|
| 7 |
+
from fastai.learner import load_learner
|
| 8 |
+
fastai_available = True
|
| 9 |
+
except ImportError:
|
| 10 |
+
fastai_available = False
|
| 11 |
+
print("FastAI is not installed. Please install it using: pip install fastai")
|
| 12 |
+
|
| 13 |
+
# Emotion classes
|
| 14 |
+
emotions = ["Angry", "Happy", "Neutral", "Sad", "Surprise"] # Model's predicted classes
|
| 15 |
+
|
| 16 |
+
# Load the model
|
| 17 |
+
try:
|
| 18 |
+
print("Loading the model...")
|
| 19 |
+
if os.path.exists('optimized_emotion_classifier.pkl'):
|
| 20 |
+
if fastai_available:
|
| 21 |
+
try:
|
| 22 |
+
print("Trying to load the model with FastAI...")
|
| 23 |
+
model = load_learner('optimized_emotion_classifier.pkl')
|
| 24 |
+
print("Model loaded successfully with FastAI!")
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Error occurred while loading the model with FastAI: {e}")
|
| 27 |
+
# Try to load as pickle file as a fallback
|
| 28 |
+
with open('optimized_emotion_classifier.pkl', 'rb') as f:
|
| 29 |
+
model = pickle.load(f)
|
| 30 |
+
print("Model loaded successfully as pickle file!")
|
| 31 |
+
else:
|
| 32 |
+
print("FastAI library is not installed. Please install it using: pip install fastai")
|
| 33 |
+
raise ImportError("FastAI library is not installed.")
|
| 34 |
+
else:
|
| 35 |
+
print("Model file not found!")
|
| 36 |
+
raise FileNotFoundError("optimized_emotion_classifier.pkl file not found.")
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"Critical error occurred while loading the model: {e}")
|
| 39 |
+
raise
|
| 40 |
+
|
| 41 |
+
# Preprocess the image
|
| 42 |
+
def preprocess_image(image):
|
| 43 |
+
if image is None:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
# Resize the image to the model's expected size (224x224 for ResNet models)
|
| 47 |
+
img = Image.fromarray(image).resize((224, 224))
|
| 48 |
+
|
| 49 |
+
# Return the PIL image for FastAI model
|
| 50 |
+
return img
|
| 51 |
+
|
| 52 |
+
# Predict the emotion
|
| 53 |
+
def predict_emotion(image):
|
| 54 |
+
if image is None:
|
| 55 |
+
return "Please upload an image or take a photo", None
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
# Preprocess the image
|
| 59 |
+
processed_image = preprocess_image(image)
|
| 60 |
+
|
| 61 |
+
# FastAI model prediction
|
| 62 |
+
prediction = model.predict(processed_image)
|
| 63 |
+
|
| 64 |
+
# FastAI prediction[0] returns the class name, prediction[2] returns the probabilities
|
| 65 |
+
emotion = prediction[0]
|
| 66 |
+
|
| 67 |
+
# All emotions probabilities
|
| 68 |
+
probs = prediction[2].numpy()
|
| 69 |
+
confidence = {emotions[i]: float(probs[i]) for i in range(len(emotions))}
|
| 70 |
+
|
| 71 |
+
return emotion, confidence
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"Prediction error: {e}")
|
| 74 |
+
return f"Error occurred: {str(e)}", None
|
| 75 |
+
|
| 76 |
+
# Capture photo and predict emotion function
|
| 77 |
+
def capture_and_predict(image_input):
|
| 78 |
+
# If webcam is open, capture photo and predict
|
| 79 |
+
if image_input is not None:
|
| 80 |
+
return predict_emotion(image_input)
|
| 81 |
+
else:
|
| 82 |
+
return "Camera is not open or photo is not taken", None
|
| 83 |
+
|
| 84 |
+
# Helper function to get first image from emotion folder
|
| 85 |
+
def get_first_image_from_emotion(emotion_folder):
|
| 86 |
+
folder_path = os.path.join("EMOTION RECOGNITION DATASET", emotion_folder)
|
| 87 |
+
if os.path.exists(folder_path) and os.listdir(folder_path):
|
| 88 |
+
return os.path.join(folder_path, os.listdir(folder_path)[0])
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
# Gradio interface
|
| 92 |
+
with gr.Blocks(title="Emotion Recognition", theme=gr.themes.Default()) as demo:
|
| 93 |
+
gr.Markdown("# Emotion Recognition Application")
|
| 94 |
+
gr.Markdown("This application recognizes your facial emotion. Upload an image or take a photo with your webcam.")
|
| 95 |
+
|
| 96 |
+
with gr.Row():
|
| 97 |
+
with gr.Column(scale=1):
|
| 98 |
+
# Enable webcam with higher resolution and clear instructions
|
| 99 |
+
input_image = gr.Image(
|
| 100 |
+
label="Upload Image or Use Webcam",
|
| 101 |
+
type="numpy",
|
| 102 |
+
sources=["upload", "webcam", "clipboard"], # Added clipboard as a source
|
| 103 |
+
height=300
|
| 104 |
+
)
|
| 105 |
+
submit_btn = gr.Button("Predict", variant="primary")
|
| 106 |
+
|
| 107 |
+
with gr.Column(scale=1):
|
| 108 |
+
output_emotion = gr.Textbox(label="Predicted Emotion")
|
| 109 |
+
output_confidence = gr.Label(label="Confidence Levels")
|
| 110 |
+
|
| 111 |
+
# Add example images for testing
|
| 112 |
+
try:
|
| 113 |
+
happy_example = os.path.join("EMOTION RECOGNITION DATASET", "Happy", os.listdir(os.path.join("EMOTION RECOGNITION DATASET", "Happy"))[0])
|
| 114 |
+
sad_example = os.path.join("EMOTION RECOGNITION DATASET", "Sad", os.listdir(os.path.join("EMOTION RECOGNITION DATASET", "Sad"))[0])
|
| 115 |
+
angry_example = "gettyimages.jpg"
|
| 116 |
+
surprise_example = os.path.join("EMOTION RECOGNITION DATASET", "Surprise", os.listdir(os.path.join("EMOTION RECOGNITION DATASET", "Surprise"))[3])
|
| 117 |
+
neutral_example = os.path.join("EMOTION RECOGNITION DATASET", "Neutral", os.listdir(os.path.join("EMOTION RECOGNITION DATASET", "Neutral"))[0])
|
| 118 |
+
gr.Examples(
|
| 119 |
+
examples=[happy_example, sad_example, angry_example, surprise_example,neutral_example],
|
| 120 |
+
inputs=input_image,
|
| 121 |
+
)
|
| 122 |
+
except:
|
| 123 |
+
print("Could not load example images")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
submit_btn.click(
|
| 127 |
+
fn=capture_and_predict,
|
| 128 |
+
inputs=input_image,
|
| 129 |
+
outputs=[output_emotion, output_confidence]
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Automatically run prediction when image changes
|
| 133 |
+
input_image.change(
|
| 134 |
+
fn=predict_emotion,
|
| 135 |
+
inputs=input_image,
|
| 136 |
+
outputs=[output_emotion, output_confidence]
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Start the application
|
| 140 |
+
# En sondaki demo.launch satırını bu şekilde değiştirin
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
print("Application starting...")
|
| 143 |
+
try:
|
| 144 |
+
public_url = demo.launch(share=True, server_name="0.0.0.0")
|
| 145 |
+
print(f"Application started successfully. Public URL: {public_url}")
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error launching application: {e}")
|
| 148 |
+
print("Trying to launch without sharing...")
|
| 149 |
+
demo.launch(share=False) # Paylaşım olmadan deneme
|
optimized_emotion_classifier.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cdc2471060b75ce518fc93031da9a1f17df100b366baf73c046842d94c413c3
|
| 3 |
+
size 87706985
|
optimized_emotion_evaluation.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Optimized Emotion Recognition Model Evaluation
|
| 3 |
+
This script provides comprehensive evaluation of the trained emotion recognition model
|
| 4 |
+
with detailed metrics, visualizations, and performance analysis.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from fastai.vision.all import *
|
| 12 |
+
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
| 13 |
+
import seaborn as sns
|
| 14 |
+
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
|
| 15 |
+
from sklearn.preprocessing import label_binarize
|
| 16 |
+
|
| 17 |
+
def load_model(model_path='optimized_emotion_classifier.pkl'):
|
| 18 |
+
"""
|
| 19 |
+
Loads a trained model from disk
|
| 20 |
+
"""
|
| 21 |
+
try:
|
| 22 |
+
learn = load_learner(model_path)
|
| 23 |
+
print(f"Model successfully loaded: {model_path}")
|
| 24 |
+
return learn
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Error loading model: {e}")
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
def evaluate_accuracy(learn, test_path=None):
|
| 30 |
+
"""
|
| 31 |
+
Evaluates model accuracy on test data
|
| 32 |
+
"""
|
| 33 |
+
print("=" * 50)
|
| 34 |
+
print("MODEL ACCURACY EVALUATION")
|
| 35 |
+
print("=" * 50)
|
| 36 |
+
|
| 37 |
+
# Use default test path if none provided
|
| 38 |
+
if test_path is None:
|
| 39 |
+
test_path = Path("EMOTION RECOGNITION DATASET")
|
| 40 |
+
|
| 41 |
+
# Create test dataloader with labels
|
| 42 |
+
dls = learn.dls.test_dl(get_image_files(test_path), with_labels=True, num_workers=0)
|
| 43 |
+
|
| 44 |
+
# Get predictions and targets
|
| 45 |
+
preds, targets = learn.get_preds(dl=dls)
|
| 46 |
+
pred_classes = preds.argmax(dim=1)
|
| 47 |
+
|
| 48 |
+
# Calculate accuracy
|
| 49 |
+
accuracy = accuracy_score(targets, pred_classes)
|
| 50 |
+
print(f"Test accuracy: {accuracy:.4f}")
|
| 51 |
+
|
| 52 |
+
# Classification report
|
| 53 |
+
class_names = learn.dls.vocab
|
| 54 |
+
report = classification_report(targets, pred_classes, target_names=class_names)
|
| 55 |
+
print("Classification Report:")
|
| 56 |
+
print(report)
|
| 57 |
+
|
| 58 |
+
# Save report to file
|
| 59 |
+
with open('evaluation_report.txt', 'w') as f:
|
| 60 |
+
f.write(f"Test accuracy: {accuracy:.4f}\n\n")
|
| 61 |
+
f.write(report)
|
| 62 |
+
|
| 63 |
+
return preds, targets, class_names
|
| 64 |
+
|
| 65 |
+
def visualize_confusion_matrix(targets, pred_classes, class_names):
|
| 66 |
+
"""
|
| 67 |
+
Creates and saves a detailed confusion matrix
|
| 68 |
+
"""
|
| 69 |
+
print("=" * 50)
|
| 70 |
+
print("CONFUSION MATRIX")
|
| 71 |
+
print("=" * 50)
|
| 72 |
+
|
| 73 |
+
# Calculate confusion matrix
|
| 74 |
+
cm = confusion_matrix(targets, pred_classes)
|
| 75 |
+
|
| 76 |
+
# Normalize confusion matrix
|
| 77 |
+
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
| 78 |
+
|
| 79 |
+
# Plot confusion matrix
|
| 80 |
+
plt.figure(figsize=(12, 10))
|
| 81 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
| 82 |
+
xticklabels=class_names, yticklabels=class_names)
|
| 83 |
+
plt.xlabel('Predicted')
|
| 84 |
+
plt.ylabel('True')
|
| 85 |
+
plt.title('Confusion Matrix')
|
| 86 |
+
plt.savefig('evaluation_confusion_matrix.png')
|
| 87 |
+
plt.close()
|
| 88 |
+
|
| 89 |
+
# Plot normalized confusion matrix
|
| 90 |
+
plt.figure(figsize=(12, 10))
|
| 91 |
+
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
|
| 92 |
+
xticklabels=class_names, yticklabels=class_names)
|
| 93 |
+
plt.xlabel('Predicted')
|
| 94 |
+
plt.ylabel('True')
|
| 95 |
+
plt.title('Normalized Confusion Matrix')
|
| 96 |
+
plt.savefig('evaluation_normalized_confusion_matrix.png')
|
| 97 |
+
plt.close()
|
| 98 |
+
|
| 99 |
+
return cm, cm_norm
|
| 100 |
+
|
| 101 |
+
def visualize_roc_curves(preds, targets, class_names):
|
| 102 |
+
"""
|
| 103 |
+
Plots ROC curves for multi-class classification
|
| 104 |
+
"""
|
| 105 |
+
print("=" * 50)
|
| 106 |
+
print("ROC CURVES")
|
| 107 |
+
print("=" * 50)
|
| 108 |
+
|
| 109 |
+
# Convert targets to one-hot encoding
|
| 110 |
+
n_classes = len(class_names)
|
| 111 |
+
targets_one_hot = label_binarize(targets, classes=range(n_classes))
|
| 112 |
+
|
| 113 |
+
# Compute ROC curve and ROC area for each class
|
| 114 |
+
fpr = {}
|
| 115 |
+
tpr = {}
|
| 116 |
+
roc_auc = {}
|
| 117 |
+
|
| 118 |
+
plt.figure(figsize=(12, 10))
|
| 119 |
+
|
| 120 |
+
for i in range(n_classes):
|
| 121 |
+
fpr[i], tpr[i], _ = roc_curve(targets_one_hot[:, i], preds[:, i])
|
| 122 |
+
roc_auc[i] = auc(fpr[i], tpr[i])
|
| 123 |
+
|
| 124 |
+
plt.plot(fpr[i], tpr[i], lw=2,
|
| 125 |
+
label=f'ROC curve of {class_names[i]} (area = {roc_auc[i]:.2f})')
|
| 126 |
+
|
| 127 |
+
# Plot random chance line
|
| 128 |
+
plt.plot([0, 1], [0, 1], 'k--', lw=2)
|
| 129 |
+
plt.xlim([0.0, 1.0])
|
| 130 |
+
plt.ylim([0.0, 1.05])
|
| 131 |
+
plt.xlabel('False Positive Rate')
|
| 132 |
+
plt.ylabel('True Positive Rate')
|
| 133 |
+
plt.title('Receiver Operating Characteristic (ROC) Curves')
|
| 134 |
+
plt.legend(loc="lower right")
|
| 135 |
+
plt.savefig('evaluation_roc_curves.png')
|
| 136 |
+
plt.close()
|
| 137 |
+
|
| 138 |
+
return roc_auc
|
| 139 |
+
|
| 140 |
+
def visualize_precision_recall(preds, targets, class_names):
|
| 141 |
+
"""
|
| 142 |
+
Plots Precision-Recall curves for multi-class classification
|
| 143 |
+
"""
|
| 144 |
+
print("=" * 50)
|
| 145 |
+
print("PRECISION-RECALL CURVES")
|
| 146 |
+
print("=" * 50)
|
| 147 |
+
|
| 148 |
+
# Convert targets to one-hot encoding
|
| 149 |
+
n_classes = len(class_names)
|
| 150 |
+
targets_one_hot = label_binarize(targets, classes=range(n_classes))
|
| 151 |
+
|
| 152 |
+
# Compute Precision-Recall curve and average precision for each class
|
| 153 |
+
precision = {}
|
| 154 |
+
recall = {}
|
| 155 |
+
avg_precision = {}
|
| 156 |
+
|
| 157 |
+
plt.figure(figsize=(12, 10))
|
| 158 |
+
|
| 159 |
+
for i in range(n_classes):
|
| 160 |
+
precision[i], recall[i], _ = precision_recall_curve(targets_one_hot[:, i], preds[:, i])
|
| 161 |
+
avg_precision[i] = average_precision_score(targets_one_hot[:, i], preds[:, i])
|
| 162 |
+
|
| 163 |
+
plt.plot(recall[i], precision[i], lw=2,
|
| 164 |
+
label=f'{class_names[i]} (AP = {avg_precision[i]:.2f})')
|
| 165 |
+
|
| 166 |
+
plt.xlim([0.0, 1.0])
|
| 167 |
+
plt.ylim([0.0, 1.05])
|
| 168 |
+
plt.xlabel('Recall')
|
| 169 |
+
plt.ylabel('Precision')
|
| 170 |
+
plt.title('Precision-Recall Curves')
|
| 171 |
+
plt.legend(loc="best")
|
| 172 |
+
plt.savefig('evaluation_precision_recall_curves.png')
|
| 173 |
+
plt.close()
|
| 174 |
+
|
| 175 |
+
return avg_precision
|
| 176 |
+
|
| 177 |
+
def analyze_prediction_confidence(preds, targets, class_names):
|
| 178 |
+
"""
|
| 179 |
+
Analyzes prediction confidence distributions for correct and incorrect predictions
|
| 180 |
+
"""
|
| 181 |
+
print("=" * 50)
|
| 182 |
+
print("PREDICTION CONFIDENCE ANALYSIS")
|
| 183 |
+
print("=" * 50)
|
| 184 |
+
|
| 185 |
+
# Get highest probability for each prediction
|
| 186 |
+
pred_classes = preds.argmax(dim=1)
|
| 187 |
+
pred_probs = preds.max(dim=1)[0]
|
| 188 |
+
|
| 189 |
+
# Separate probabilities for correct and incorrect predictions
|
| 190 |
+
correct_mask = pred_classes == targets
|
| 191 |
+
correct_probs = pred_probs[correct_mask]
|
| 192 |
+
incorrect_probs = pred_probs[~correct_mask]
|
| 193 |
+
|
| 194 |
+
# Plot confidence distributions
|
| 195 |
+
plt.figure(figsize=(12, 8))
|
| 196 |
+
|
| 197 |
+
# Plot histograms
|
| 198 |
+
if len(correct_probs) > 0:
|
| 199 |
+
plt.hist(correct_probs.numpy(), alpha=0.5, bins=20,
|
| 200 |
+
label=f'Correct predictions (n={len(correct_probs)})')
|
| 201 |
+
|
| 202 |
+
if len(incorrect_probs) > 0:
|
| 203 |
+
plt.hist(incorrect_probs.numpy(), alpha=0.5, bins=20,
|
| 204 |
+
label=f'Incorrect predictions (n={len(incorrect_probs)})')
|
| 205 |
+
|
| 206 |
+
plt.xlabel('Prediction Confidence')
|
| 207 |
+
plt.ylabel('Count')
|
| 208 |
+
plt.title('Distribution of Prediction Confidence')
|
| 209 |
+
plt.legend()
|
| 210 |
+
plt.savefig('evaluation_confidence_distribution.png')
|
| 211 |
+
plt.close()
|
| 212 |
+
|
| 213 |
+
# Calculate statistics
|
| 214 |
+
if len(correct_probs) > 0:
|
| 215 |
+
avg_correct_conf = correct_probs.mean().item()
|
| 216 |
+
print(f"Average confidence for correct predictions: {avg_correct_conf:.4f}")
|
| 217 |
+
else:
|
| 218 |
+
avg_correct_conf = 0
|
| 219 |
+
print("No correct predictions found")
|
| 220 |
+
|
| 221 |
+
if len(incorrect_probs) > 0:
|
| 222 |
+
avg_incorrect_conf = incorrect_probs.mean().item()
|
| 223 |
+
print(f"Average confidence for incorrect predictions: {avg_incorrect_conf:.4f}")
|
| 224 |
+
else:
|
| 225 |
+
avg_incorrect_conf = 0
|
| 226 |
+
print("No incorrect predictions found")
|
| 227 |
+
|
| 228 |
+
return avg_correct_conf, avg_incorrect_conf
|
| 229 |
+
|
| 230 |
+
def visualize_sample_predictions(learn, test_path=None, num_samples=10):
|
| 231 |
+
"""
|
| 232 |
+
Visualizes sample predictions with probabilities
|
| 233 |
+
"""
|
| 234 |
+
print("=" * 50)
|
| 235 |
+
print("SAMPLE PREDICTIONS")
|
| 236 |
+
print("=" * 50)
|
| 237 |
+
|
| 238 |
+
if test_path is None:
|
| 239 |
+
test_path = Path("EMOTION RECOGNITION DATASET")
|
| 240 |
+
|
| 241 |
+
# Get random test images
|
| 242 |
+
test_files = get_image_files(test_path)
|
| 243 |
+
if len(test_files) > num_samples:
|
| 244 |
+
test_files = np.random.choice(test_files, num_samples, replace=False)
|
| 245 |
+
|
| 246 |
+
# Create figure for visualization
|
| 247 |
+
rows = min(3, num_samples)
|
| 248 |
+
cols = int(np.ceil(num_samples / rows))
|
| 249 |
+
fig, axes = plt.subplots(rows, cols, figsize=(cols*5, rows*5))
|
| 250 |
+
axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]
|
| 251 |
+
|
| 252 |
+
# Make predictions and plot
|
| 253 |
+
for i, img_path in enumerate(test_files):
|
| 254 |
+
if i >= len(axes):
|
| 255 |
+
break
|
| 256 |
+
|
| 257 |
+
# Load image and predict
|
| 258 |
+
img = PILImage.create(img_path)
|
| 259 |
+
pred_class, pred_idx, probs = learn.predict(img)
|
| 260 |
+
|
| 261 |
+
# Get true label from parent directory
|
| 262 |
+
true_label = img_path.parent.name
|
| 263 |
+
|
| 264 |
+
# Plot image
|
| 265 |
+
axes[i].imshow(img)
|
| 266 |
+
|
| 267 |
+
# Set title color based on correctness
|
| 268 |
+
if true_label == pred_class:
|
| 269 |
+
title_color = 'green'
|
| 270 |
+
else:
|
| 271 |
+
title_color = 'red'
|
| 272 |
+
|
| 273 |
+
# Set title with prediction info
|
| 274 |
+
axes[i].set_title(f"Pred: {pred_class} ({probs[pred_idx]:.2f})\nTrue: {true_label}",
|
| 275 |
+
color=title_color)
|
| 276 |
+
axes[i].axis('off')
|
| 277 |
+
|
| 278 |
+
# Remove unused axes
|
| 279 |
+
for i in range(len(test_files), len(axes)):
|
| 280 |
+
fig.delaxes(axes[i])
|
| 281 |
+
|
| 282 |
+
plt.tight_layout()
|
| 283 |
+
plt.savefig('evaluation_sample_predictions.png')
|
| 284 |
+
plt.close()
|
| 285 |
+
|
| 286 |
+
def evaluate_class_wise_metrics(targets, pred_classes, class_names):
|
| 287 |
+
"""
|
| 288 |
+
Calculates and visualizes class-wise performance metrics
|
| 289 |
+
"""
|
| 290 |
+
print("=" * 50)
|
| 291 |
+
print("CLASS-WISE PERFORMANCE METRICS")
|
| 292 |
+
print("=" * 50)
|
| 293 |
+
|
| 294 |
+
# Calculate per-class metrics from confusion matrix
|
| 295 |
+
cm = confusion_matrix(targets, pred_classes)
|
| 296 |
+
|
| 297 |
+
# Per-class accuracy (diagonal divided by row sum)
|
| 298 |
+
class_accuracy = np.diag(cm) / np.sum(cm, axis=1)
|
| 299 |
+
|
| 300 |
+
# Per-class precision (diagonal divided by column sum)
|
| 301 |
+
class_precision = np.diag(cm) / np.sum(cm, axis=0)
|
| 302 |
+
|
| 303 |
+
# Per-class recall (same as per-class accuracy)
|
| 304 |
+
class_recall = class_accuracy
|
| 305 |
+
|
| 306 |
+
# Per-class F1 score
|
| 307 |
+
class_f1 = 2 * (class_precision * class_recall) / (class_precision + class_recall)
|
| 308 |
+
|
| 309 |
+
# Create dataframe for metrics
|
| 310 |
+
metrics_df = pd.DataFrame({
|
| 311 |
+
'Class': class_names,
|
| 312 |
+
'Accuracy': class_accuracy,
|
| 313 |
+
'Precision': class_precision,
|
| 314 |
+
'Recall': class_recall,
|
| 315 |
+
'F1 Score': class_f1
|
| 316 |
+
})
|
| 317 |
+
|
| 318 |
+
print(metrics_df)
|
| 319 |
+
|
| 320 |
+
# Save metrics to CSV
|
| 321 |
+
metrics_df.to_csv('class_wise_metrics.csv', index=False)
|
| 322 |
+
|
| 323 |
+
# Plot metrics
|
| 324 |
+
plt.figure(figsize=(12, 8))
|
| 325 |
+
|
| 326 |
+
# Create bar positions
|
| 327 |
+
x = np.arange(len(class_names))
|
| 328 |
+
width = 0.2
|
| 329 |
+
|
| 330 |
+
# Plot bars
|
| 331 |
+
plt.bar(x - width*1.5, class_accuracy, width, label='Accuracy')
|
| 332 |
+
plt.bar(x - width/2, class_precision, width, label='Precision')
|
| 333 |
+
plt.bar(x + width/2, class_recall, width, label='Recall')
|
| 334 |
+
plt.bar(x + width*1.5, class_f1, width, label='F1')
|
| 335 |
+
|
| 336 |
+
# Add labels and legend
|
| 337 |
+
plt.xlabel('Class')
|
| 338 |
+
plt.ylabel('Score')
|
| 339 |
+
plt.title('Class-wise Performance Metrics')
|
| 340 |
+
plt.xticks(x, class_names, rotation=45)
|
| 341 |
+
plt.legend()
|
| 342 |
+
plt.tight_layout()
|
| 343 |
+
plt.savefig('evaluation_class_wise_metrics.png')
|
| 344 |
+
plt.close()
|
| 345 |
+
|
| 346 |
+
return metrics_df
|
| 347 |
+
|
| 348 |
+
def main():
|
| 349 |
+
"""
|
| 350 |
+
Main evaluation function
|
| 351 |
+
"""
|
| 352 |
+
print("OPTIMIZED EMOTION RECOGNITION MODEL EVALUATION")
|
| 353 |
+
print("=" * 50)
|
| 354 |
+
|
| 355 |
+
# Load trained model
|
| 356 |
+
model_path = 'optimized_emotion_classifier.pkl'
|
| 357 |
+
learn = load_model(model_path)
|
| 358 |
+
|
| 359 |
+
if learn is None:
|
| 360 |
+
print("Failed to load model. Trying backup model...")
|
| 361 |
+
model_path = 'emotion_classifier.pkl'
|
| 362 |
+
learn = load_model(model_path)
|
| 363 |
+
|
| 364 |
+
if learn is None:
|
| 365 |
+
print("Could not load any model. Evaluation aborted.")
|
| 366 |
+
return
|
| 367 |
+
|
| 368 |
+
# Evaluate model accuracy
|
| 369 |
+
preds, targets, class_names = evaluate_accuracy(learn)
|
| 370 |
+
pred_classes = preds.argmax(dim=1)
|
| 371 |
+
|
| 372 |
+
# Visualize confusion matrix
|
| 373 |
+
cm, cm_norm = visualize_confusion_matrix(targets, pred_classes, class_names)
|
| 374 |
+
|
| 375 |
+
# Calculate ROC curves
|
| 376 |
+
roc_auc = visualize_roc_curves(preds, targets, class_names)
|
| 377 |
+
|
| 378 |
+
# Calculate precision-recall curves
|
| 379 |
+
avg_precision = visualize_precision_recall(preds, targets, class_names)
|
| 380 |
+
|
| 381 |
+
# Analyze prediction confidence
|
| 382 |
+
avg_correct_conf, avg_incorrect_conf = analyze_prediction_confidence(preds, targets, class_names)
|
| 383 |
+
|
| 384 |
+
# Visualize sample predictions
|
| 385 |
+
visualize_sample_predictions(learn, num_samples=9)
|
| 386 |
+
|
| 387 |
+
# Calculate class-wise metrics
|
| 388 |
+
metrics_df = evaluate_class_wise_metrics(targets, pred_classes, class_names)
|
| 389 |
+
|
| 390 |
+
print("=" * 50)
|
| 391 |
+
print("EVALUATION SUMMARY")
|
| 392 |
+
print("=" * 50)
|
| 393 |
+
|
| 394 |
+
# Overall accuracy
|
| 395 |
+
accuracy = accuracy_score(targets, pred_classes)
|
| 396 |
+
print(f"Overall accuracy: {accuracy:.4f}")
|
| 397 |
+
|
| 398 |
+
# Average metrics across classes
|
| 399 |
+
print(f"Average ROC AUC: {np.mean(list(roc_auc.values())):.4f}")
|
| 400 |
+
print(f"Average Precision: {np.mean(list(avg_precision.values())):.4f}")
|
| 401 |
+
|
| 402 |
+
# Confidence gap
|
| 403 |
+
conf_gap = avg_correct_conf - avg_incorrect_conf
|
| 404 |
+
print(f"Confidence gap (correct-incorrect): {conf_gap:.4f}")
|
| 405 |
+
|
| 406 |
+
print("\nClass-wise F1 scores:")
|
| 407 |
+
for i, class_name in enumerate(class_names):
|
| 408 |
+
print(f" {class_name}: {metrics_df['F1 Score'][i]:.4f}")
|
| 409 |
+
|
| 410 |
+
print("\nEvaluation completed. Results saved to files.")
|
| 411 |
+
|
| 412 |
+
if __name__ == "__main__":
|
| 413 |
+
main()
|
optimized_emotion_model.py
ADDED
|
@@ -0,0 +1,745 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Optimized Emotion Recognition Model Training
|
| 3 |
+
This script implements a comprehensive deep learning training pipeline with
|
| 4 |
+
optimized parameters for emotion recognition tasks.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from fastai.vision.all import *
|
| 12 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 13 |
+
import seaborn as sns
|
| 14 |
+
from fastai.metrics import accuracy, error_rate
|
| 15 |
+
|
| 16 |
+
accuracy_metric = accuracy
|
| 17 |
+
error_rate_metric = error_rate
|
| 18 |
+
|
| 19 |
+
# A.1. Download the data
|
| 20 |
+
# In this case, we're using the existing EMOTION RECOGNITION DATASET
|
| 21 |
+
def inspect_dataset(data_path):
|
| 22 |
+
"""
|
| 23 |
+
A.1.1. Inspect the data layout
|
| 24 |
+
Analyzes the dataset structure and distribution
|
| 25 |
+
"""
|
| 26 |
+
print("=" * 50)
|
| 27 |
+
print("DATASET INSPECTION")
|
| 28 |
+
print("=" * 50)
|
| 29 |
+
|
| 30 |
+
# Check available classes
|
| 31 |
+
classes = os.listdir(data_path)
|
| 32 |
+
print(f"Classes in the dataset: {classes}")
|
| 33 |
+
|
| 34 |
+
# Count images per class
|
| 35 |
+
class_counts = {}
|
| 36 |
+
total_images = 0
|
| 37 |
+
for emotion in classes:
|
| 38 |
+
files = os.listdir(data_path/emotion)
|
| 39 |
+
class_counts[emotion] = len(files)
|
| 40 |
+
total_images += len(files)
|
| 41 |
+
print(f"{emotion} class: {len(files)} images")
|
| 42 |
+
|
| 43 |
+
print(f"Total images: {total_images}")
|
| 44 |
+
|
| 45 |
+
# Plot class distribution
|
| 46 |
+
plt.figure(figsize=(10, 6))
|
| 47 |
+
plt.bar(class_counts.keys(), class_counts.values())
|
| 48 |
+
plt.title('Class Distribution')
|
| 49 |
+
plt.xlabel('Emotion')
|
| 50 |
+
plt.ylabel('Number of Images')
|
| 51 |
+
plt.savefig('class_distribution.png')
|
| 52 |
+
plt.close()
|
| 53 |
+
|
| 54 |
+
return classes, class_counts
|
| 55 |
+
|
| 56 |
+
# A.2. Create the DataBlock and dataloaders
|
| 57 |
+
def create_datablock(data_path, img_size=224, batch_size=64, valid_pct=0.2):
|
| 58 |
+
"""
|
| 59 |
+
Creates an optimized DataBlock for training
|
| 60 |
+
A.1.2 Decision on how to create datablock based on dataset structure
|
| 61 |
+
A.2.1 It defined the blocks
|
| 62 |
+
A.2.2 It defined the means of getting data into DataBlock
|
| 63 |
+
A.2.3 It defined how to get the attributes
|
| 64 |
+
A.2.4 It defined data transformations with presizing
|
| 65 |
+
"""
|
| 66 |
+
print("=" * 50)
|
| 67 |
+
print("CREATING DATABLOCK")
|
| 68 |
+
print("=" * 50)
|
| 69 |
+
|
| 70 |
+
# A.1.2 Decision: We'll use parent folder name as the label for classification
|
| 71 |
+
print("A.1.2: Using folder structure for class labels, with parent_label")
|
| 72 |
+
|
| 73 |
+
# Define the DataBlock with all required components
|
| 74 |
+
emotion_data = DataBlock(
|
| 75 |
+
# A.2.1 Define blocks (input and target types)
|
| 76 |
+
blocks=(ImageBlock, CategoryBlock),
|
| 77 |
+
|
| 78 |
+
# A.2.2 Define how to get data
|
| 79 |
+
get_items=get_image_files,
|
| 80 |
+
|
| 81 |
+
# Data splitting strategy - random with fixed seed for reproducibility
|
| 82 |
+
splitter=RandomSplitter(valid_pct=valid_pct, seed=42),
|
| 83 |
+
|
| 84 |
+
# A.2.3 Define how to get labels (from parent folder name)
|
| 85 |
+
get_y=parent_label,
|
| 86 |
+
|
| 87 |
+
# A.2.4 Define transformations with presizing strategy
|
| 88 |
+
# First resize (item by item)
|
| 89 |
+
item_tfms=[Resize(img_size, method='squish')],
|
| 90 |
+
|
| 91 |
+
# Then apply augmentations (batch by batch)
|
| 92 |
+
batch_tfms=[
|
| 93 |
+
# Augmentations (applied to batch)
|
| 94 |
+
*aug_transforms(size=img_size-32, min_scale=0.75,
|
| 95 |
+
flip_vert=False, max_rotate=10.0, max_zoom=1.1),
|
| 96 |
+
# Normalize using ImageNet stats
|
| 97 |
+
Normalize.from_stats(*imagenet_stats)
|
| 98 |
+
]
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Create dataloaders
|
| 102 |
+
print(f"Creating dataloaders with batch size: {batch_size}")
|
| 103 |
+
dls = emotion_data.dataloaders(data_path, bs=batch_size, num_workers=0)
|
| 104 |
+
|
| 105 |
+
return emotion_data, dls
|
| 106 |
+
|
| 107 |
+
# A.3. Inspect the DataBlock via dataloader
|
| 108 |
+
def inspect_datablock(emotion_data, dls, data_path):
|
| 109 |
+
"""
|
| 110 |
+
Inspects the created DataBlock and visualizes samples
|
| 111 |
+
A.3.1 Show batch
|
| 112 |
+
A.3.2 Check the labels
|
| 113 |
+
A.3.3 Summarize the DataBlock
|
| 114 |
+
"""
|
| 115 |
+
print("=" * 50)
|
| 116 |
+
print("INSPECTING DATABLOCK")
|
| 117 |
+
print("=" * 50)
|
| 118 |
+
|
| 119 |
+
# A.3.1 Show batch
|
| 120 |
+
print("Displaying sample batch...")
|
| 121 |
+
dls.show_batch(max_n=9, figsize=(12, 10))
|
| 122 |
+
plt.savefig('sample_batch.png')
|
| 123 |
+
plt.close()
|
| 124 |
+
|
| 125 |
+
# A.3.2 Check labels
|
| 126 |
+
print(f"Classes (labels): {dls.vocab}")
|
| 127 |
+
|
| 128 |
+
# A.3.3 DataBlock summary
|
| 129 |
+
print("DataBlock summary:")
|
| 130 |
+
emotion_data.summary(data_path)
|
| 131 |
+
|
| 132 |
+
return dls.vocab
|
| 133 |
+
|
| 134 |
+
# A.4. Train a simple model (benchmark)
|
| 135 |
+
def train_benchmark_model(dls, model_name='resnet18'):
|
| 136 |
+
"""
|
| 137 |
+
A.4.1 Create a benchmark model for comparison
|
| 138 |
+
"""
|
| 139 |
+
print("=" * 50)
|
| 140 |
+
print(f"TRAINING BENCHMARK MODEL: {model_name}")
|
| 141 |
+
print("=" * 50)
|
| 142 |
+
|
| 143 |
+
# Create learner with simple architecture
|
| 144 |
+
learn = vision_learner(dls,
|
| 145 |
+
eval(model_name),
|
| 146 |
+
metrics=[error_rate_metric, accuracy_metric])
|
| 147 |
+
|
| 148 |
+
# Quick training for benchmark
|
| 149 |
+
learn.fine_tune(3, base_lr=1e-2)
|
| 150 |
+
|
| 151 |
+
# Save benchmark results
|
| 152 |
+
learn.save('benchmark_model')
|
| 153 |
+
|
| 154 |
+
# A.4.2 & A.4.3 Interpret model and create confusion matrix
|
| 155 |
+
interpret_model(learn, "benchmark")
|
| 156 |
+
|
| 157 |
+
return learn
|
| 158 |
+
|
| 159 |
+
# Helper function to interpret model performance
|
| 160 |
+
def interpret_model(learn, name_prefix=""):
|
| 161 |
+
"""
|
| 162 |
+
A.4.2 Interpret the model
|
| 163 |
+
A.4.3 Confusion matrix
|
| 164 |
+
Creates visualizations to understand model performance
|
| 165 |
+
"""
|
| 166 |
+
interp = ClassificationInterpretation.from_learner(learn)
|
| 167 |
+
|
| 168 |
+
# Plot confusion matrix
|
| 169 |
+
plt.figure(figsize=(10, 8))
|
| 170 |
+
interp.plot_confusion_matrix(figsize=(10, 8))
|
| 171 |
+
plt.savefig(f'{name_prefix}_confusion_matrix.png')
|
| 172 |
+
plt.close()
|
| 173 |
+
|
| 174 |
+
# Plot top losses
|
| 175 |
+
interp.plot_top_losses(9, figsize=(12, 10))
|
| 176 |
+
plt.savefig(f'{name_prefix}_top_losses.png')
|
| 177 |
+
plt.close()
|
| 178 |
+
|
| 179 |
+
# Compute classification report
|
| 180 |
+
probs, targets = learn.get_preds()
|
| 181 |
+
preds = probs.argmax(dim=1)
|
| 182 |
+
report = classification_report(targets, preds, target_names=learn.dls.vocab)
|
| 183 |
+
print(f"Classification Report:\n{report}")
|
| 184 |
+
|
| 185 |
+
# Save report to file
|
| 186 |
+
with open(f'{name_prefix}_report.txt', 'w') as f:
|
| 187 |
+
f.write(report)
|
| 188 |
+
|
| 189 |
+
return interp
|
| 190 |
+
|
| 191 |
+
# B.1 and B.2 Learning Rate Finder
|
| 192 |
+
def find_learning_rate(learn):
|
| 193 |
+
"""
|
| 194 |
+
B.1 & B.2 Implements learning rate finder
|
| 195 |
+
Helps find optimal learning rate for training
|
| 196 |
+
"""
|
| 197 |
+
print("=" * 50)
|
| 198 |
+
print("FINDING OPTIMAL LEARNING RATE")
|
| 199 |
+
print("=" * 50)
|
| 200 |
+
|
| 201 |
+
# Run learning rate finder
|
| 202 |
+
lr_suggestions = learn.lr_find(suggest_funcs=(valley, slide))
|
| 203 |
+
|
| 204 |
+
# Plot and save results
|
| 205 |
+
plt.savefig('learning_rate_finder.png')
|
| 206 |
+
plt.close()
|
| 207 |
+
|
| 208 |
+
print(f"Suggested learning rates: {lr_suggestions}")
|
| 209 |
+
|
| 210 |
+
return lr_suggestions
|
| 211 |
+
|
| 212 |
+
# Custom manual learning rate finder implementation
|
| 213 |
+
def manual_lr_finder(learn, start_lr=1e-7, factor=2, max_iterations=100, max_lr=10):
|
| 214 |
+
"""
|
| 215 |
+
B.2 Custom implementation of learning rate finder
|
| 216 |
+
B.2.1 Start with a very very low lr
|
| 217 |
+
B.2.2 Train one batch with lr, record loss
|
| 218 |
+
B.2.3 Increase lr to 2*lr
|
| 219 |
+
B.2.4 Train one batch with 2*lr, record the new loss
|
| 220 |
+
B.2.5 If the new loss is smaller than old loss, continue increasing
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
learn: fastai Learner object
|
| 224 |
+
start_lr: initial learning rate (very low)
|
| 225 |
+
factor: multiplier for increasing learning rate
|
| 226 |
+
max_iterations: maximum number of iterations
|
| 227 |
+
max_lr: maximum learning rate to try
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
optimal_lr: the learning rate that gave the best loss
|
| 231 |
+
lrs: list of all learning rates tried
|
| 232 |
+
losses: list of corresponding losses
|
| 233 |
+
"""
|
| 234 |
+
print("=" * 50)
|
| 235 |
+
print("RUNNING MANUAL LEARNING RATE FINDER")
|
| 236 |
+
print("=" * 50)
|
| 237 |
+
|
| 238 |
+
# Save the current learning rate and model weights
|
| 239 |
+
original_lr = learn.opt.hypers[0]['lr']
|
| 240 |
+
learn.save('temp_weights')
|
| 241 |
+
|
| 242 |
+
# Initialize variables
|
| 243 |
+
lr = start_lr
|
| 244 |
+
lrs = []
|
| 245 |
+
losses = []
|
| 246 |
+
best_loss = float('inf')
|
| 247 |
+
optimal_lr = start_lr
|
| 248 |
+
|
| 249 |
+
# Get a single batch from training data
|
| 250 |
+
xb, yb = next(iter(learn.dls.train))
|
| 251 |
+
|
| 252 |
+
print(f"Starting with learning rate: {lr}")
|
| 253 |
+
|
| 254 |
+
# Implement learning rate finder algorithm
|
| 255 |
+
for i in range(max_iterations):
|
| 256 |
+
# B.2.1 & B.2.3: Set the current learning rate
|
| 257 |
+
learn.opt.set_hyper('lr', lr)
|
| 258 |
+
|
| 259 |
+
# B.2.2 & B.2.4: Train on one batch and record loss
|
| 260 |
+
learn.opt.zero_grad()
|
| 261 |
+
loss = learn.loss_func(learn.model(xb), yb)
|
| 262 |
+
loss.backward()
|
| 263 |
+
learn.opt.step()
|
| 264 |
+
|
| 265 |
+
# Record results
|
| 266 |
+
lrs.append(lr)
|
| 267 |
+
current_loss = loss.item()
|
| 268 |
+
losses.append(current_loss)
|
| 269 |
+
|
| 270 |
+
print(f"Iteration {i+1}: lr={lr:.8f}, loss={current_loss:.4f}")
|
| 271 |
+
|
| 272 |
+
# Update best learning rate if this loss is better
|
| 273 |
+
if current_loss < best_loss:
|
| 274 |
+
best_loss = current_loss
|
| 275 |
+
optimal_lr = lr
|
| 276 |
+
|
| 277 |
+
# Check if loss is NaN or too large (diverging)
|
| 278 |
+
if np.isnan(current_loss) or current_loss > 4 * best_loss:
|
| 279 |
+
print(f"Stopping early: loss is {'NaN' if np.isnan(current_loss) else 'diverging'}")
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
# B.2.3 & B.2.5: Increase lr by factor for next iteration
|
| 283 |
+
lr *= factor
|
| 284 |
+
|
| 285 |
+
# Stop if lr exceeds max_lr
|
| 286 |
+
if lr > max_lr:
|
| 287 |
+
print(f"Reached maximum learning rate: {max_lr}")
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
# Plot learning rate vs loss
|
| 291 |
+
plt.figure(figsize=(10, 6))
|
| 292 |
+
plt.plot(lrs, losses)
|
| 293 |
+
plt.xscale('log')
|
| 294 |
+
plt.xlabel('Learning Rate (log scale)')
|
| 295 |
+
plt.ylabel('Loss')
|
| 296 |
+
plt.title('Manual Learning Rate Finder')
|
| 297 |
+
plt.axvline(x=optimal_lr, color='r', linestyle='--')
|
| 298 |
+
plt.savefig('manual_lr_finder.png')
|
| 299 |
+
plt.close()
|
| 300 |
+
|
| 301 |
+
# Restore original settings
|
| 302 |
+
learn.opt.set_hyper('lr', original_lr)
|
| 303 |
+
learn.load('temp_weights')
|
| 304 |
+
|
| 305 |
+
print(f"Optimal learning rate found: {optimal_lr}")
|
| 306 |
+
|
| 307 |
+
return optimal_lr, lrs, losses
|
| 308 |
+
|
| 309 |
+
# B.3 Implementation of Transfer Learning strategies
|
| 310 |
+
def compare_transfer_learning_strategies(dls, model_arch='resnet34', lr=1e-3, epochs=1):
|
| 311 |
+
"""
|
| 312 |
+
B.3 Compares different transfer learning strategies:
|
| 313 |
+
|
| 314 |
+
B.3.1 Replace the trained final linear layer F with a new one F'
|
| 315 |
+
(where size changes from mxn to mxk, k is the number of classes in new task)
|
| 316 |
+
B.3.2 Only train F' while keeping previous weights frozen
|
| 317 |
+
B.3.3 Continue training all weights (F' and previous) by unfreezing
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
dls: DataLoaders object
|
| 321 |
+
model_arch: Model architecture to use
|
| 322 |
+
lr: Learning rate to use
|
| 323 |
+
epochs: Number of epochs for each strategy
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
best_strategy: The learner with the best strategy applied
|
| 327 |
+
results: Dictionary with results of all strategies
|
| 328 |
+
"""
|
| 329 |
+
print("=" * 50)
|
| 330 |
+
print("COMPARING TRANSFER LEARNING STRATEGIES")
|
| 331 |
+
print("=" * 50)
|
| 332 |
+
|
| 333 |
+
results = {}
|
| 334 |
+
|
| 335 |
+
# Strategy 1: Train only the head (final layer) with frozen body
|
| 336 |
+
print("\nStrategy 1: Train only the head (final layer) with frozen body")
|
| 337 |
+
learn1 = vision_learner(dls, eval(model_arch), metrics=[error_rate_metric, accuracy_metric], pretrained=True) # B.3.3 vision_learner fonksiyonu çağrıldığında, önceden eğitilmiş bir model yüklenir ve son katmanı (head) otomatik olarak yeni görevin sınıf sayısına uygun şekilde değiştirilir. Bu, tam olarak eski son doğrusal katman F'yi (boyutu mxn) yeni bir F' (boyutu mxk) ile değiştirme işlemidir.
|
| 338 |
+
|
| 339 |
+
# The model comes with pre-trained weights, and the head is already replaced with a new one
|
| 340 |
+
# matching the number of classes in our task (this happens automatically in vision_learner)
|
| 341 |
+
print("Model summary before training:")
|
| 342 |
+
print(learn1.model)
|
| 343 |
+
|
| 344 |
+
# Train only the head (Freezing)
|
| 345 |
+
learn1.fit_one_cycle(epochs, lr) #B.3. Only train F' while using the previous weights unchanged (called Freezing)
|
| 346 |
+
|
| 347 |
+
# Evaluate Strategy 1
|
| 348 |
+
metrics = learn1.validate()
|
| 349 |
+
frozen_acc = metrics[2] # Accuracy değeri (3. eleman)
|
| 350 |
+
results['frozen_head_only'] = frozen_acc
|
| 351 |
+
print(f"Accuracy with frozen body, trained head: {frozen_acc:.4f}")
|
| 352 |
+
learn1.save('transfer_strategy1')
|
| 353 |
+
|
| 354 |
+
# Strategy 2: Progressive unfreezing (train final layer, then unfreeze gradually)
|
| 355 |
+
print("\nStrategy 2: Progressive unfreezing")
|
| 356 |
+
learn2 = vision_learner(dls, eval(model_arch), metrics=[error_rate_metric, accuracy_metric], pretrained=True)
|
| 357 |
+
#B.3. Continue training all the weights F' and previous weights (called Unfreezing)
|
| 358 |
+
# First train only the head
|
| 359 |
+
learn2.fit_one_cycle(epochs, lr)
|
| 360 |
+
|
| 361 |
+
# Then unfreeze all layers and train with discriminative learning rates
|
| 362 |
+
learn2.unfreeze()
|
| 363 |
+
|
| 364 |
+
# B.1. Learning Rate considerations:
|
| 365 |
+
# B.1.1 Using too large learning rates: Convergence will be poor, if not impossible
|
| 366 |
+
# B.1.2 Using too small learning rates: Convergence will be slow, takes too many epochs, risks overfitting
|
| 367 |
+
|
| 368 |
+
# B.2. This implementation applies similar principles to Learning Rate Finder:
|
| 369 |
+
# B.2.1 We determined a base learning rate (lr) through experimentation
|
| 370 |
+
# B.2.2 For early layers (general features), we use a smaller lr (lr/100) to make minor tweaks
|
| 371 |
+
# B.2.3 For later layers (specific features), we use a larger lr (lr/10) for more significant updates
|
| 372 |
+
# B.2.4 This forms a "slice" of learning rates that increases from early to later layers
|
| 373 |
+
# B.2.5 Similar to how Learning Rate Finder gradually increases lr to find optimal values
|
| 374 |
+
learn2.fit_one_cycle(epochs, slice(lr/100, lr/10))
|
| 375 |
+
|
| 376 |
+
# The slice(lr/100, lr/10) creates a range of learning rates:
|
| 377 |
+
# - Early layers: lr/100 (very small adjustments to preserve general features)
|
| 378 |
+
# - Middle layers: gradually increasing rates
|
| 379 |
+
# - Later layers: lr/10 (larger adjustments to adapt specific features to our task)
|
| 380 |
+
# This is an application of the ULMFIT approach for transfer learning
|
| 381 |
+
|
| 382 |
+
# Evaluate Strategy 2
|
| 383 |
+
metrics = learn2.validate()
|
| 384 |
+
progressive_acc = metrics[2] # Accuracy
|
| 385 |
+
results['progressive_unfreezing'] = progressive_acc
|
| 386 |
+
print(f"Accuracy with progressive unfreezing: {progressive_acc:.4f}")
|
| 387 |
+
learn2.save('transfer_strategy2')
|
| 388 |
+
|
| 389 |
+
# Strategy 3: Fine-tuning (fastai's recommended approach)
|
| 390 |
+
print("\nStrategy 3: Fine-tuning (fastai's approach)")
|
| 391 |
+
learn3 = vision_learner(dls, eval(model_arch), metrics=[error_rate_metric, accuracy_metric], pretrained=True)
|
| 392 |
+
|
| 393 |
+
# fine_tune automatically does: train head, unfreeze, train all
|
| 394 |
+
learn3.fine_tune(epochs)
|
| 395 |
+
|
| 396 |
+
# Evaluate Strategy 3
|
| 397 |
+
metrics = learn3.validate()
|
| 398 |
+
finetune_acc = metrics[2] # Accuracy
|
| 399 |
+
results['fine_tune'] = finetune_acc
|
| 400 |
+
print(f"Accuracy with fine_tune: {finetune_acc:.4f}")
|
| 401 |
+
learn3.save('transfer_strategy3')
|
| 402 |
+
|
| 403 |
+
# Determine best strategy
|
| 404 |
+
best_acc = max(results.values())
|
| 405 |
+
best_strategy_name = [k for k, v in results.items() if v == best_acc][0]
|
| 406 |
+
|
| 407 |
+
print("\nTransfer Learning Strategy Comparison:")
|
| 408 |
+
for strategy, accuracy in results.items():
|
| 409 |
+
print(f"{strategy}: {accuracy:.4f}")
|
| 410 |
+
|
| 411 |
+
print(f"\nBest strategy: {best_strategy_name} with accuracy {best_acc:.4f}")
|
| 412 |
+
|
| 413 |
+
# Plot results
|
| 414 |
+
plt.figure(figsize=(10, 6))
|
| 415 |
+
strategies = list(results.keys())
|
| 416 |
+
accuracies = [results[s] for s in strategies]
|
| 417 |
+
|
| 418 |
+
plt.bar(strategies, accuracies)
|
| 419 |
+
plt.ylim(0, 1.0)
|
| 420 |
+
plt.xlabel('Transfer Learning Strategy')
|
| 421 |
+
plt.ylabel('Validation Accuracy')
|
| 422 |
+
plt.title('Comparison of Transfer Learning Strategies')
|
| 423 |
+
plt.savefig('transfer_learning_comparison.png')
|
| 424 |
+
plt.close()
|
| 425 |
+
|
| 426 |
+
# Return the best learner
|
| 427 |
+
if best_strategy_name == 'frozen_head_only':
|
| 428 |
+
return learn1, results
|
| 429 |
+
elif best_strategy_name == 'progressive_unfreezing':
|
| 430 |
+
return learn2, results
|
| 431 |
+
else:
|
| 432 |
+
return learn3, results
|
| 433 |
+
|
| 434 |
+
# B.3-B.7 Advanced training with all optimizations
|
| 435 |
+
def train_optimized_model(dls, model_arch='resnet34', epochs=3, batch_size=32):
|
| 436 |
+
"""
|
| 437 |
+
Trains a model with all optimizations including:
|
| 438 |
+
- B.3 Transfer Learning
|
| 439 |
+
- B.4 Discriminative Learning Rates
|
| 440 |
+
- B.5 Optimal Number of Training Epochs
|
| 441 |
+
- B.6 Model Capacity Adjustments
|
| 442 |
+
- B.7 Proper Weight Initialization
|
| 443 |
+
"""
|
| 444 |
+
print("=" * 50)
|
| 445 |
+
print(f"TRAINING OPTIMIZED MODEL: {model_arch}")
|
| 446 |
+
print("=" * 50)
|
| 447 |
+
|
| 448 |
+
# B.6.1: When increasing the model capacity, ensure smaller batch size
|
| 449 |
+
print(f"B.6.1: Using batch size {batch_size} for model {model_arch}")
|
| 450 |
+
|
| 451 |
+
# Create learner with selected architecture (B.6 Model Capacity)
|
| 452 |
+
# Transfer learning is automatically applied by using a pretrained model (B.3)
|
| 453 |
+
# B.3: Replace the final linear layer F with a new one F' for our classification task
|
| 454 |
+
print("B.3: Applying transfer learning - replacing final linear layer with new one for our task")
|
| 455 |
+
learn = vision_learner(dls,
|
| 456 |
+
eval(model_arch),
|
| 457 |
+
metrics=[error_rate_metric, accuracy_metric],
|
| 458 |
+
pretrained=True)
|
| 459 |
+
|
| 460 |
+
# B.6.3 Apply mixed precision training to optimize memory usage and speed
|
| 461 |
+
# This is similar to quantization in LLMs - using float16 instead of float32
|
| 462 |
+
learn.to_fp16()
|
| 463 |
+
print("B.6.3: Applied mixed precision training (float16) for better GPU memory usage")
|
| 464 |
+
|
| 465 |
+
# B.2: Learning Rate Finder implementation
|
| 466 |
+
# B.2.1-B.2.5: Finding optimal learning rate by starting with very low lr and gradually increasing
|
| 467 |
+
print("B.2: Running learning rate finder to determine optimal learning rate")
|
| 468 |
+
print("B.2.1-B.2.5: Starting with very low learning rate and gradually increasing")
|
| 469 |
+
|
| 470 |
+
# Use built-in learning rate finder
|
| 471 |
+
lr_suggestions = find_learning_rate(learn)
|
| 472 |
+
fastai_suggested_lr = lr_suggestions[0]
|
| 473 |
+
|
| 474 |
+
# Our manual learning rate finder implementation
|
| 475 |
+
optimal_lr, lrs, losses = manual_lr_finder(learn, start_lr=1e-7, factor=3, max_iterations=15)
|
| 476 |
+
|
| 477 |
+
print(f"Learning rate finder results:")
|
| 478 |
+
print(f"- Fastai suggested learning rate: {fastai_suggested_lr}")
|
| 479 |
+
print(f"- Our manual finder suggested: {optimal_lr}")
|
| 480 |
+
|
| 481 |
+
# Select final learning rate based on finder results
|
| 482 |
+
final_lr = fastai_suggested_lr
|
| 483 |
+
print(f"Selected learning rate: {final_lr}")
|
| 484 |
+
|
| 485 |
+
# B.4: Discriminative Learning Rates
|
| 486 |
+
print("B.4: Applying Discriminative Learning Rates")
|
| 487 |
+
print("B.4.1-B.4.2: Earlier layers need smaller learning rates, newer layers need larger tweaks")
|
| 488 |
+
print("B.4.3-B.4.5: Earlier layers have more general features, later layers have more specific features")
|
| 489 |
+
|
| 490 |
+
# B.5: Deciding the Number of Training Epochs
|
| 491 |
+
print("B.5: Optimal epoch selection instead of early stopping")
|
| 492 |
+
print("B.5.1-B.5.3: We don't use early stopping as it may counteract learning rate finder")
|
| 493 |
+
|
| 494 |
+
# B.3: Applying best transfer learning strategy
|
| 495 |
+
# First train only the new head (freezing pre-trained layers)
|
| 496 |
+
print("B.3: Training only the new head first (Freezing)")
|
| 497 |
+
learn.fit_one_cycle(1, final_lr)
|
| 498 |
+
|
| 499 |
+
# Then unfreeze all layers and train with discriminative learning rates
|
| 500 |
+
print("B.3: Continue training all weights (Unfreezing)")
|
| 501 |
+
print("B.4.6-B.4.7: Using slice of learning rates - smaller for early layers, larger for later layers")
|
| 502 |
+
|
| 503 |
+
# B.4.6-B.4.7: Apply discriminative learning rates using slice
|
| 504 |
+
# Early layers get smaller learning rate (lr/100), later layers get higher (lr/10)
|
| 505 |
+
learn.unfreeze()
|
| 506 |
+
|
| 507 |
+
# B.1. Learning Rate considerations:
|
| 508 |
+
# B.1.1 Using too large learning rates: Convergence will be poor, if not impossible
|
| 509 |
+
# B.1.2 Using too small learning rates: Convergence will be slow, takes too many epochs, risks overfitting
|
| 510 |
+
|
| 511 |
+
# B.2. This implementation applies principles derived from Learning Rate Finder:
|
| 512 |
+
# B.2.1 We determined an optimal learning rate (final_lr) using learning rate finder
|
| 513 |
+
# B.2.2 For early layers (general features), we use a smaller lr (final_lr/100) for minimal adjustments
|
| 514 |
+
# B.2.3 For later layers (specific features), we use a larger lr (final_lr/10) for more significant updates
|
| 515 |
+
# B.2.4 This forms a "slice" of learning rates that increases from early to later layers
|
| 516 |
+
# B.2.5 Each layer gets an appropriate learning rate based on its depth in the network
|
| 517 |
+
learn.fit_one_cycle(epochs-1, slice(final_lr/100, final_lr/10))
|
| 518 |
+
|
| 519 |
+
# The slice(final_lr/100, final_lr/10) creates a range of learning rates:
|
| 520 |
+
# - Early layers: final_lr/100 (very small adjustments to preserve general features)
|
| 521 |
+
# - Middle layers: gradually increasing rates
|
| 522 |
+
# - Later layers: final_lr/10 (larger adjustments to adapt specific features to our task)
|
| 523 |
+
# This is an application of the ULMFIT approach for transfer learning, optimized based on our findings
|
| 524 |
+
|
| 525 |
+
# Save final model
|
| 526 |
+
learn.save('optimized_model_final')
|
| 527 |
+
learn.export('optimized_emotion_classifier.pkl')
|
| 528 |
+
|
| 529 |
+
# Evaluate final model
|
| 530 |
+
print("Evaluating final optimized model...")
|
| 531 |
+
interpret_model(learn, "optimized")
|
| 532 |
+
|
| 533 |
+
return learn
|
| 534 |
+
|
| 535 |
+
# Full training pipeline
|
| 536 |
+
def main():
|
| 537 |
+
# Set seed for reproducibility
|
| 538 |
+
set_seed(42, reproducible=True)
|
| 539 |
+
|
| 540 |
+
# Data path setup
|
| 541 |
+
data_path = Path("EMOTION RECOGNITION DATASET")
|
| 542 |
+
|
| 543 |
+
print("OPTIMIZED EMOTION RECOGNITION MODEL TRAINING")
|
| 544 |
+
print("=" * 50)
|
| 545 |
+
|
| 546 |
+
# A.1 Data inspection
|
| 547 |
+
classes, class_counts = inspect_dataset(data_path)
|
| 548 |
+
|
| 549 |
+
# A.2 Create DataBlock with optimal parameters for main model
|
| 550 |
+
result = create_datablock(
|
| 551 |
+
data_path,
|
| 552 |
+
img_size=224, # Standard size for most pretrained models
|
| 553 |
+
batch_size=32, # B.6.1: Adjusted smaller batch size for larger model capacity
|
| 554 |
+
valid_pct=0.2 # 80/20 split
|
| 555 |
+
)
|
| 556 |
+
emotion_data = result[0]
|
| 557 |
+
dls = result[1]
|
| 558 |
+
|
| 559 |
+
# A.3 Inspect DataBlock
|
| 560 |
+
class_names = inspect_datablock(emotion_data, dls, data_path)
|
| 561 |
+
|
| 562 |
+
# STEP 1: A.4 Train and save benchmark model (ResNet18)
|
| 563 |
+
print("\n=== STEP 1: TRAIN AND SAVE BENCHMARK MODEL (ResNet18) ===")
|
| 564 |
+
benchmark_model = train_benchmark_model(dls, model_name='resnet18')
|
| 565 |
+
|
| 566 |
+
# Benchmark model (ResNet18) metrics
|
| 567 |
+
metrics = benchmark_model.validate()
|
| 568 |
+
benchmark_valid_loss = metrics[0] # loss
|
| 569 |
+
benchmark_accuracy = metrics[2] # accuracy (metrics = [loss, error_rate, accuracy])
|
| 570 |
+
print(f"Benchmark Model (ResNet18) - Accuracy: {benchmark_accuracy:.4f}, Valid Loss: {benchmark_valid_loss:.4f}")
|
| 571 |
+
benchmark_model.save('benchmark_resnet18_model')
|
| 572 |
+
|
| 573 |
+
# STEP 2: Model comparison (ResNet18 vs ResNet34)
|
| 574 |
+
# B.6: Model Capacity - comparing different model capacities
|
| 575 |
+
print("\n=== STEP 2: MODEL COMPARISON (ResNet18 vs ResNet34) ===")
|
| 576 |
+
# Train ResNet34 model
|
| 577 |
+
print("Training ResNet34 model...")
|
| 578 |
+
resnet34_model = train_benchmark_model(dls, model_name='resnet34')
|
| 579 |
+
|
| 580 |
+
# ResNet34 metrics
|
| 581 |
+
metrics = resnet34_model.validate()
|
| 582 |
+
resnet34_valid_loss = metrics[0] # loss
|
| 583 |
+
resnet34_accuracy = metrics[2] # accuracy
|
| 584 |
+
print(f"ResNet34 Model - Accuracy: {resnet34_accuracy:.4f}, Valid Loss: {resnet34_valid_loss:.4f}")
|
| 585 |
+
resnet34_model.save('benchmark_resnet34_model')
|
| 586 |
+
|
| 587 |
+
# Compare both models and select the best one
|
| 588 |
+
print("\n--- MODEL COMPARISON RESULTS ---")
|
| 589 |
+
# B.6: Model Capacity Adjustments - Comparing models with different capacities (ResNet18 vs ResNet34)
|
| 590 |
+
# This replaces the more general compare_model_capacities function with a specific implementation
|
| 591 |
+
model_results = {
|
| 592 |
+
'resnet18': {
|
| 593 |
+
'accuracy': benchmark_accuracy,
|
| 594 |
+
'valid_loss': benchmark_valid_loss,
|
| 595 |
+
'model': benchmark_model
|
| 596 |
+
},
|
| 597 |
+
'resnet34': {
|
| 598 |
+
'accuracy': resnet34_accuracy,
|
| 599 |
+
'valid_loss': resnet34_valid_loss,
|
| 600 |
+
'model': resnet34_model
|
| 601 |
+
}
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
# Select model with highest accuracy
|
| 605 |
+
best_model_arch = max(model_results.keys(), key=lambda k: model_results[k]['accuracy'])
|
| 606 |
+
best_accuracy = model_results[best_model_arch]['accuracy']
|
| 607 |
+
best_model = model_results[best_model_arch]['model']
|
| 608 |
+
|
| 609 |
+
print(f"Model Comparison Results:")
|
| 610 |
+
for arch, metrics in model_results.items():
|
| 611 |
+
print(f"{arch}: Accuracy={metrics['accuracy']:.4f}, Valid Loss={metrics['valid_loss']:.4f}")
|
| 612 |
+
|
| 613 |
+
print(f"\nBest model: {best_model_arch} (accuracy: {best_accuracy:.4f})")
|
| 614 |
+
|
| 615 |
+
# Apply B.6.1: When increasing model capacity, decrease batch size
|
| 616 |
+
# This implements a key principle from compare_model_capacities function
|
| 617 |
+
best_batch_size = 32 if best_model_arch == 'resnet34' else 64
|
| 618 |
+
print(f"B.6.1: Using batch size {best_batch_size} for {best_model_arch}")
|
| 619 |
+
|
| 620 |
+
# Create comparison graph
|
| 621 |
+
plt.figure(figsize=(10, 6))
|
| 622 |
+
archs = list(model_results.keys())
|
| 623 |
+
accs = [model_results[arch]['accuracy'] for arch in archs]
|
| 624 |
+
losses = [model_results[arch]['valid_loss'] for arch in archs]
|
| 625 |
+
|
| 626 |
+
plt.subplot(1, 2, 1)
|
| 627 |
+
plt.bar(archs, accs)
|
| 628 |
+
plt.title('Model Accuracy Comparison')
|
| 629 |
+
plt.ylim(0, 1.0)
|
| 630 |
+
|
| 631 |
+
plt.subplot(1, 2, 2)
|
| 632 |
+
plt.bar(archs, losses)
|
| 633 |
+
plt.title('Model Loss Comparison')
|
| 634 |
+
|
| 635 |
+
plt.tight_layout()
|
| 636 |
+
plt.savefig('model_architecture_comparison.png')
|
| 637 |
+
plt.close()
|
| 638 |
+
|
| 639 |
+
# STEP 3: Transfer Learning Strategies Comparison
|
| 640 |
+
# B.3: Compare different transfer learning strategies
|
| 641 |
+
print(f"\n=== STEP 3: TRANSFER LEARNING STRATEGIES COMPARISON ({best_model_arch}) ===")
|
| 642 |
+
|
| 643 |
+
try:
|
| 644 |
+
# Create new DataLoader for selected best model architecture
|
| 645 |
+
result = create_datablock(
|
| 646 |
+
data_path,
|
| 647 |
+
img_size=224,
|
| 648 |
+
batch_size=best_batch_size, # B.6.1: Using appropriate batch size for model capacity
|
| 649 |
+
valid_pct=0.2
|
| 650 |
+
)
|
| 651 |
+
transfer_dls = result[1]
|
| 652 |
+
|
| 653 |
+
# B.3: Compare transfer learning strategies
|
| 654 |
+
print("B.3: Comparing different transfer learning strategies:")
|
| 655 |
+
print("1. Training only new head with frozen body")
|
| 656 |
+
print("2. Progressive unfreezing")
|
| 657 |
+
print("3. Fine-tuning (fastai approach)")
|
| 658 |
+
|
| 659 |
+
best_strategy_model, strategy_results = compare_transfer_learning_strategies(
|
| 660 |
+
transfer_dls,
|
| 661 |
+
model_arch=best_model_arch, # Use the best selected model architecture
|
| 662 |
+
lr=1e-3,
|
| 663 |
+
# B.5: Optimal Training Epochs - Using a fixed number instead of dynamically determining optimal epochs
|
| 664 |
+
# This replaces the determine_optimal_epochs function with a simpler approach
|
| 665 |
+
epochs=3
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
print(f"\nTransfer Learning Strategies Results:")
|
| 669 |
+
for strategy, accuracy in strategy_results.items():
|
| 670 |
+
print(f"{strategy}: {accuracy:.4f}")
|
| 671 |
+
|
| 672 |
+
# Determine best strategy
|
| 673 |
+
best_strategy = max(strategy_results.keys(), key=lambda k: strategy_results[k])
|
| 674 |
+
print(f"Best transfer strategy: {best_strategy}")
|
| 675 |
+
|
| 676 |
+
# Transfer strategies comparison graph
|
| 677 |
+
plt.figure(figsize=(10, 6))
|
| 678 |
+
plt.bar(strategy_results.keys(), strategy_results.values())
|
| 679 |
+
plt.title('Transfer Learning Strategies Comparison')
|
| 680 |
+
plt.ylim(0, 1.0)
|
| 681 |
+
plt.tight_layout()
|
| 682 |
+
plt.savefig('transfer_learning_comparison.png')
|
| 683 |
+
plt.close()
|
| 684 |
+
|
| 685 |
+
except Exception as e:
|
| 686 |
+
print(f"Transfer learning strategy comparison error: {e}")
|
| 687 |
+
print("Using default fine-tune strategy")
|
| 688 |
+
best_strategy = "fine_tune" # Default strategy
|
| 689 |
+
|
| 690 |
+
# STEP 4: Final optimized model training
|
| 691 |
+
# Applying all advanced techniques B.2-B.6
|
| 692 |
+
print(f"\n=== STEP 4: FINAL OPTIMIZED MODEL TRAINING ===")
|
| 693 |
+
print(f"Architecture: {best_model_arch}, Batch Size: {best_batch_size}, Strategy: {best_strategy}")
|
| 694 |
+
|
| 695 |
+
try:
|
| 696 |
+
# Create DataLoader for final model
|
| 697 |
+
result = create_datablock(
|
| 698 |
+
data_path,
|
| 699 |
+
img_size=224,
|
| 700 |
+
batch_size=best_batch_size, # B.6.1: Adjusted batch size for model capacity
|
| 701 |
+
valid_pct=0.2
|
| 702 |
+
)
|
| 703 |
+
final_dls = result[1]
|
| 704 |
+
|
| 705 |
+
# Train optimized model with all advanced techniques
|
| 706 |
+
# B.2: Learning Rate Finder
|
| 707 |
+
# B.3: Transfer Learning
|
| 708 |
+
# B.4: Discriminative Learning Rates - Applied inside train_optimized_model using slice(lr/100, lr/10)
|
| 709 |
+
# B.5: Optimal Epoch Selection - Using fixed epochs instead of dynamically determining optimal number
|
| 710 |
+
# B.6: Model Capacity Adjustments - Using the best model architecture selected earlier
|
| 711 |
+
print("Training final model with all optimizations:")
|
| 712 |
+
print("- B.2: Learning Rate Finder")
|
| 713 |
+
print("- B.3: Transfer Learning")
|
| 714 |
+
print("- B.4: Discriminative Learning Rates")
|
| 715 |
+
print("- B.5: Optimal Epoch Selection")
|
| 716 |
+
print("- B.6: Model Capacity Adjustments")
|
| 717 |
+
|
| 718 |
+
optimized_model = train_optimized_model(
|
| 719 |
+
final_dls,
|
| 720 |
+
model_arch=best_model_arch, # Selected best architecture
|
| 721 |
+
# B.5: Using fixed epochs instead of dynamically determining optimal number
|
| 722 |
+
epochs=3, # Final model epochs
|
| 723 |
+
batch_size=best_batch_size # B.6.1: Adjusted batch size according to model capacity
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
# STEP 5: Final model evaluation
|
| 727 |
+
print("\n=== STEP 5: FINAL MODEL EVALUATION ===")
|
| 728 |
+
metrics = optimized_model.validate()
|
| 729 |
+
final_loss = metrics[0] # loss
|
| 730 |
+
final_accuracy = metrics[2] # accuracy
|
| 731 |
+
print(f"Final Optimized Model - Accuracy: {final_accuracy:.4f}, Loss: {final_loss:.4f}")
|
| 732 |
+
|
| 733 |
+
# Confusion matrix and classification report
|
| 734 |
+
interpret_model(optimized_model, "final_optimized")
|
| 735 |
+
|
| 736 |
+
print("\nTraining completed successfully!")
|
| 737 |
+
print(f"Model saved as: optimized_model_final")
|
| 738 |
+
print(f"Exported as: optimized_emotion_classifier.pkl")
|
| 739 |
+
|
| 740 |
+
except Exception as e:
|
| 741 |
+
print(f"Final optimized model training error: {e}")
|
| 742 |
+
print("Training could not be completed. Check error message.")
|
| 743 |
+
|
| 744 |
+
if __name__ == "__main__":
|
| 745 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastai>=2.7.0
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
numpy>=1.21.0
|
| 4 |
+
pandas>=1.3.0
|
| 5 |
+
matplotlib>=3.4.0
|
| 6 |
+
scikit-learn>=1.0.0
|
| 7 |
+
pillow>=8.3.0
|
| 8 |
+
seaborn>=0.11.0
|
| 9 |
+
gradio>=3.32.0
|