# Moodify / app.py
# MReq's picture
# Upload 6 files
# 389d5e1 verified
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import io
import logging
import tensorflow as tf
from tensorflow import keras
import cv2
# ----------------- LOGGER SETUP -----------------
# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Named logger shared by every function in this module.
logger = logging.getLogger("face-analysis")
# ----------------- LOAD MODELS -----------------
# Both models are loaded once, at import time, and reused by every request.
# Emotion model (H5 format), loaded from a relative path.
H5_MODEL_PATH = "my_model3.h5"
# Target size handed to PIL's resize() when preparing emotion-model input.
INPUT_SIZE = (48, 48)
emotion_model = keras.models.load_model(H5_MODEL_PATH)
logger.info("Emotion model loaded successfully")
# Log the shapes the loaded model reports so mismatches show up at startup.
logger.info(f"Model input shape: {emotion_model.input_shape}")
logger.info(f"Model output shape: {emotion_model.output_shape}")
# Age model: the classifier and its image processor are both loaded from the
# same model identifier.
age_model_name = "prithivMLmods/facial-age-detection"
age_model = SiglipForImageClassification.from_pretrained(age_model_name)
age_processor = AutoImageProcessor.from_pretrained(age_model_name)
# Face detection cascade
# Try the cascade file given by relative path first; if that classifier comes
# back empty, retry with the copy shipped in OpenCV's data directory.
HAAR_CASCADE_PATH = 'haarcascade_frontalface_default.xml'
face_cascade = cv2.CascadeClassifier(HAAR_CASCADE_PATH)
if not face_cascade.empty():
    logger.info(f"Haar Cascade loaded successfully from {HAAR_CASCADE_PATH}")
else:
    logger.error(f"Failed to load Haar Cascade from {HAAR_CASCADE_PATH}")
    logger.warning("Attempting to load from OpenCV data directory...")
    # Second attempt: same file name, resolved inside cv2.data.haarcascades.
    HAAR_CASCADE_PATH = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
    face_cascade = cv2.CascadeClassifier(HAAR_CASCADE_PATH)
    if not face_cascade.empty():
        logger.info(f"Haar Cascade loaded from OpenCV data: {HAAR_CASCADE_PATH}")
    else:
        # The app keeps running; detect_and_crop_face falls back to the full image.
        logger.error("Still failed to load Haar Cascade. Face detection will not work.")
# Emotion classes
# Indexed by position: predict_emotion looks up emotions[i] for output index i.
emotions = [
    "Angry",
    "Disgust",
    "Fear",
    "Happy",
    "Sad",
    "Surprise",
    "Neutral",
]
# Age labels
# Keys are the class index rendered as a string ("0" .. "7").
id2label = {
    str(class_id): age_range
    for class_id, age_range in enumerate(
        (
            "age 01-10",
            "age 11-20",
            "age 21-30",
            "age 31-40",
            "age 41-55",
            "age 56-65",
            "age 66-80",
            "age 80+",
        )
    )
}
# ----------------- FACE DETECTION -----------------
def detect_and_crop_face(image: Image.Image):
    """
    Detect a face in ``image`` and crop to it.

    The image is converted to RGB before detection so that inputs in other
    PIL modes (RGBA, palette, CMYK, grayscale, ...) go through the same path
    as plain RGB instead of failing inside OpenCV. The crop itself is taken
    from the original, unconverted image.

    Args:
        image: PIL image to search.

    Returns:
        Tuple ``(face_or_image, message, success)``:
        - exactly one face: the crop of that face;
        - several faces: the crop of the one with the largest bounding box;
        - no face, or detection raised: the original image, unchanged.
        ``success`` is always True because every path returns a usable image.
    """
    try:
        # Normalise the mode first. Without this a 4-channel (RGBA/CMYK) array
        # makes the grayscale conversion raise, and a palette image would be
        # scanned as raw palette indices.
        rgb_array = np.asarray(image.convert("RGB"))
        logger.debug(f"Image shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")
        # Convert to grayscale for better face detection
        gray = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2GRAY)
        logger.debug(f"Grayscale shape: {gray.shape}")
        # Detect faces with more lenient parameters
        faces = face_cascade.detectMultiScale(
            gray,
            scaleFactor=1.1,   # More sensitive (was 1.3)
            minNeighbors=3,    # More lenient (was 5)
            minSize=(30, 30),  # Minimum face size
            flags=cv2.CASCADE_SCALE_IMAGE,
        )
        logger.info(f"Face detection result: {len(faces)} face(s) detected")
        # len() rather than truthiness: the result may be a NumPy array.
        if len(faces) == 0:
            logger.warning("No face detected in image - returning original image")
            # Fallback: return original image if no face detected
            return image, "⚠️ No face detected - using full image", True
        if len(faces) == 1:
            # Single face detected - crop it (plain ints for the PIL crop box)
            x, y, w, h = (int(v) for v in faces[0])
            crop_img = image.crop((x, y, x + w, y + h))
            logger.info(f"✓ Face detected and cropped: position ({x},{y}), size {w}x{h}")
            return crop_img, f"✓ Face detected at ({x},{y}), size {w}x{h}", True
        # Multiple faces detected - use the one with the largest area
        logger.warning(f"Multiple faces detected ({len(faces)}), using largest face")
        largest_face = max(faces, key=lambda face: face[2] * face[3])
        x, y, w, h = (int(v) for v in largest_face)
        crop_img = image.crop((x, y, x + w, y + h))
        return crop_img, f"⚠️ {len(faces)} faces detected, using largest one", True
    except Exception as e:
        # Best-effort by design: log with traceback and fall back to the input.
        logger.exception(f"Face detection error: {e}")
        return image, "⚠️ Face detection error - using full image", True
# ----------------- PREDICT FUNCTIONS -----------------
def preprocess_image_for_emotion(image: Image.Image):
    """
    Turn a PIL image into the input array for the H5 emotion model.

    The image is converted to single-channel grayscale, resized to
    ``INPUT_SIZE``, scaled from 0-255 down to 0-1, and given a leading batch
    axis and a trailing channel axis, producing a float32 array of shape
    (1, rows, cols, 1).
    """
    grayscale = image.convert("L").resize(INPUT_SIZE)
    pixels = np.asarray(grayscale, dtype=np.float32) / 255.0
    # (rows, cols) -> (1, rows, cols, 1)
    batch = pixels[np.newaxis, :, :, np.newaxis]
    logger.debug(f"Preprocessed shape: {batch.shape}, dtype: {batch.dtype}")
    return batch
def predict_emotion(image: Image.Image):
    """
    Classify the facial emotion shown in ``image``.

    Returns:
        On success, a dict with ``predicted_emotion`` (label of the highest
        score), ``confidence`` (that score rounded to four decimals) and
        ``all_confidences`` (label -> score for every entry in ``emotions``).
        If anything raises, the error and traceback are logged and
        ``{"error": <message>}`` is returned instead.
    """
    try:
        batch = preprocess_image_for_emotion(image)
        # Only the first row of the prediction output is used.
        scores = emotion_model.predict(batch, verbose=0)[0]
        best = np.argmax(scores)
        top_label = emotions[best]
        top_score = round(float(scores[best]), 4)
        per_label = {
            label: float(scores[position])
            for position, label in enumerate(emotions)
        }
        result = {
            "predicted_emotion": top_label,
            "confidence": top_score,
            "all_confidences": per_label,
        }
        logger.info(f"Predicted Emotion: {result['predicted_emotion']} (Confidence: {result['confidence']})")
        return result
    except Exception as e:
        logger.error(f"Emotion prediction error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return {"error": str(e)}
def predict_age(image: Image.Image):
    """
    Estimate the age group of the face in ``image``.

    Returns:
        On success, a dict with ``predicted_age`` (label of the most probable
        class), ``confidence`` (its probability rounded to four decimals) and
        ``all_confidences`` (label -> probability rounded to three decimals).
        If anything raises, the error and traceback are logged and
        ``{"error": <message>}`` is returned instead.
    """
    try:
        model_inputs = age_processor(images=image.convert("RGB"), return_tensors="pt")
        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = age_model(**model_inputs).logits
        # Softmax over the class axis, then flatten to a plain list of floats.
        probs = torch.softmax(logits, dim=1).squeeze().tolist()
        all_confidences = {
            id2label[str(class_id)]: round(probability, 3)
            for class_id, probability in enumerate(probs)
        }
        best = int(torch.argmax(torch.tensor(probs)))
        result = {
            "predicted_age": id2label[str(best)],
            "confidence": round(probs[best], 4),
            "all_confidences": all_confidences,
        }
        logger.info(f"Predicted Age Group: {result['predicted_age']} (Confidence: {result['confidence']})")
        return result
    except Exception as e:
        logger.error(f"Age prediction error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return {"error": str(e)}
# ----------------- FASTAPI APP -----------------
app = FastAPI()
# Fully open CORS policy: every origin, method and header is allowed.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# FastAPI's CORS documentation appears to advise against that pairing -
# confirm whether credentialed cross-origin requests are actually needed.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)
@app.get("/")
async def root():
    """Return a short description of the service and its endpoints."""
    endpoints = {
        "GET /": "API information",
        "GET /health": "Health check",
        "POST /predict": "Upload image for emotion and age prediction",
        "GET /gradio": "Gradio web interface",
    }
    return {
        "message": "Face Emotion + Age Detection API",
        "status": "running",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health():
    """
    Report service health.

    The model entries are fixed strings; only ``face_cascade`` reflects a
    live check (whether the classifier is empty). The emotion model's input
    and output shapes are included as strings.
    """
    cascade_state = "failed" if face_cascade.empty() else "loaded"
    input_shape = str(emotion_model.input_shape)
    output_shape = str(emotion_model.output_shape)
    return {
        "status": "ok",
        "emotion_model": "loaded",
        "age_model": "loaded",
        "face_cascade": cascade_state,
        "emotion_input_shape": input_shape,
        "emotion_output_shape": output_shape,
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Run face detection, emotion and age prediction on an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        JSON with ``face_detection`` (status message), ``emotion`` and ``age``
        (each the dict returned by the corresponding predictor, which may
        itself be ``{"error": ...}``).
        HTTP 400 with ``{"error": ...}`` when the upload cannot be decoded as
        an image; HTTP 500 with ``{"error": ...}`` on any other failure.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open() is lazy; load() forces a full decode so corrupt or
            # truncated uploads are rejected here as a client error instead
            # of surfacing later as prediction errors.
            image.load()
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            logger.warning(f"Rejected upload, not a readable image: {e}")
            return JSONResponse(
                content={"error": f"Invalid image file: {e}"},
                status_code=400,
            )
        # Detect and crop face (always returns an image, with fallback)
        cropped_face, face_msg, _success = detect_and_crop_face(image)
        # Predict on cropped face or full image (fallback)
        emotion_result = predict_emotion(cropped_face)
        age_result = predict_age(cropped_face)
        logger.info(f"API Response -> Emotion: {emotion_result.get('predicted_emotion')} | Age: {age_result.get('predicted_age')}")
        return JSONResponse(content={
            "face_detection": face_msg,
            "emotion": emotion_result,
            "age": age_result,
        })
    except Exception as e:
        logger.error(f"API Error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return JSONResponse(content={"error": str(e)}, status_code=500)
# ----------------- GRADIO DEMO -----------------
def gradio_wrapper(image):
    """
    Gradio callback: analyse one uploaded image.

    Args:
        image: PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        A 7-tuple matching the interface outputs, in order: top-emotion text,
        emotion probabilities, top-age text, age probabilities, the cropped
        face (or full image), the 48x48 grayscale preview, and the face
        detection status message. On a prediction error the two text slots
        carry the error message, the probability dicts are empty and the
        preview is None.
    """
    if image is None:
        return "No image provided", {}, "No image provided", {}, None, None, "No image uploaded"
    # Detect and crop face (always returns an image, with fallback)
    cropped_face, face_msg, _success = detect_and_crop_face(image)
    # Rebuild the exact array the emotion model sees, for visualization
    processed_image = preprocess_image_for_emotion(cropped_face)
    # Round to nearest before casting: dividing by 255 and multiplying back in
    # float32 is not exact, and a bare astype() truncates, which can render a
    # pixel one level darker than it really is.
    display_pixels = np.rint(processed_image[0, :, :, 0] * 255).astype(np.uint8)
    processed_display = Image.fromarray(display_pixels, mode='L')
    # Predict emotion and age on cropped face or full image
    emotion_result = predict_emotion(cropped_face)
    age_result = predict_age(cropped_face)
    if "error" in emotion_result or "error" in age_result:
        # Prefer the emotion error; fall back to the age error.
        error_msg = emotion_result.get("error", "") or age_result.get("error", "")
        return f"Error: {error_msg}", {}, f"Error: {error_msg}", {}, cropped_face, None, face_msg
    return (
        f"{emotion_result['predicted_emotion']} ({emotion_result['confidence']:.2f})",
        emotion_result["all_confidences"],
        f"{age_result['predicted_age']} ({age_result['confidence']:.2f})",
        age_result["all_confidences"],
        cropped_face,       # Show the cropped face or full image
        processed_display,  # Show the processed 48x48 grayscale
        face_msg,           # Face detection message
    )
# Input component first, then the outputs, so components are created in the
# same order as before.
_face_input = gr.Image(type="pil", label="Upload Face Image")
# One component per element of the tuple returned by gradio_wrapper, in order.
_result_components = [
    gr.Label(num_top_classes=1, label="Top Emotion"),
    gr.Label(label="Emotion Probabilities"),
    gr.Label(num_top_classes=1, label="Top Age Group"),
    gr.Label(label="Age Probabilities"),
    gr.Image(type="pil", label="Detected & Cropped Face"),
    gr.Image(type="pil", label="Processed Image (48x48 Grayscale)"),
    gr.Textbox(label="Face Detection Status"),
]
demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=_face_input,
    outputs=_result_components,
    title="Face Emotion + Age Detection with Face Cropping",
    description="Upload an image with a face. The system will:\n1. Detect and crop the face (or use full image if no face found)\n2. Analyze emotion (Angry, Happy, etc.)\n3. Estimate age group (01-10, 11-20, ... 80+)\n4. Show the processing steps",
    examples=None,
)
# Mount Gradio at /gradio
# `app` is rebound to the value returned by mount_gradio_app, so the Gradio UI
# is served under /gradio alongside the JSON endpoints registered above.
app = gr.mount_gradio_app(app, demo, path="/gradio")
# ----------------- RUN -----------------
if __name__ == "__main__":
    # Startup banner: model shapes and the URLs the server listens on.
    rule = "=" * 70
    logger.info(rule)
    logger.info("Starting Face Emotion + Age Detection Server")
    logger.info(rule)
    logger.info(f"Emotion Model Input Shape: {emotion_model.input_shape}")
    logger.info(f"Emotion Model Output Shape: {emotion_model.output_shape}")
    logger.info(f"Number of emotion classes: {len(emotions)}")
    logger.info("")
    logger.info("Server will be available at:")
    logger.info(" - Main API: http://0.0.0.0:7860")
    logger.info(" - Gradio UI: http://0.0.0.0:7860/gradio")
    logger.info(rule)
    # Serve the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)