import gradio as gr
import cv2
import numpy as np
import pandas as pd
import time
import mediapipe as mp
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import LineCollection
import os
import datetime
import tempfile
from typing import Dict, List, Tuple, Optional, Union, Any
import threading
import queue
import asyncio
import librosa
import torch
from moviepy.editor import VideoFileClip
from transformers import pipeline, AutoFeatureExtractor, AutoModelForAudioClassification
import google.generativeai as genai
from concurrent.futures import ThreadPoolExecutor
# --- Constants ---
VIDEO_FPS = 15  # Estimated/target FPS; also used as a fallback when the source FPS cannot be read
CSV_FILENAME_TEMPLATE = "facial_analysis_{timestamp}.csv"
VIDEO_FILENAME_TEMPLATE = "processed_{timestamp}.mp4"
AUDIO_FILENAME_TEMPLATE = "audio_{timestamp}.wav"

# --- MediaPipe Initialization ---
mp_face_mesh = mp.solutions.face_mesh
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5)

# --- Audio Model Initialization ---
# Loaded lazily in initialize_audio_model() to avoid the cost at startup
audio_classifier = None
audio_feature_extractor = None
def initialize_audio_model():
    global audio_classifier, audio_feature_extractor
    if audio_classifier is None:
        print("Loading audio classification model...")
        model_name = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
        audio_feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        audio_classifier = AutoModelForAudioClassification.from_pretrained(model_name)
        print("Audio model loaded successfully")
    return audio_classifier, audio_feature_extractor
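
# Note: the exact emotion labels exposed by this checkpoint are not hardcoded anywhere
# authoritative in this file; once the model is loaded they can be inspected from its
# config instead of guessing, e.g. (illustrative):
#   clf, _ = initialize_audio_model()
#   print(clf.config.id2label)  # index -> label mapping used by analyze_audio_segment()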
# --- Gemini API Configuration ---
# Prefer an environment variable; the hardcoded placeholder is only a local-testing fallback.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "your-gemini-api-key")

def configure_gemini():
    genai.configure(api_key=GEMINI_API_KEY)
    # Set up the model
    generation_config = {
        "temperature": 0.2,
        "top_p": 0.8,
        "top_k": 40,
        "max_output_tokens": 256,
    }
    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    ]
    try:
        model = genai.GenerativeModel(
            model_name="gemini-1.5-flash",
            generation_config=generation_config,
            safety_settings=safety_settings
        )
        return model
    except Exception as e:
        print(f"Error configuring Gemini: {e}")
        return None
# --- Metrics Definition ---
metrics = [
    "valence", "arousal", "dominance", "cognitive_load",
    "emotional_stability", "openness", "agreeableness",
    "neuroticism", "conscientiousness", "extraversion",
    "stress_index", "engagement_level"
]
audio_metrics = [
    "audio_valence", "audio_arousal", "audio_intensity",
    "audio_emotion", "audio_confidence"
]
ad_context_columns = ["ad_description", "ad_detail", "ad_type", "gemini_ad_analysis"]
user_state_column = ["user_state", "detailed_user_analysis"]
all_columns = ['timestamp', 'frame_number'] + metrics + audio_metrics + ad_context_columns + user_state_column
initial_metrics_df = pd.DataFrame(columns=all_columns)

# --- Live Processing Queue ---
processing_queue = queue.Queue()
results_queue = queue.Queue()
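
# The webcam tab uses a simple producer/consumer pattern: update_webcam_session() pushes
# frames onto processing_queue, the daemon thread started in start_webcam_session() runs
# process_frames_in_background() to do the heavy work, and finished results are drained
# from results_queue on the next streamed frame. This keeps the Gradio callback responsive,
# at the cost of the displayed metrics lagging the live feed by at least one frame.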
# --- Gemini Functions ---
def call_gemini_api_for_ad(model, description, detail, ad_type):
    """Uses Gemini to analyze ad context."""
    if not model:
        return "Gemini model not available. Using simulated analysis."
    if not description and not detail:
        return "No ad context provided."
    prompt = f"""
    Analyze this advertisement context:
    - Description: {description or 'N/A'}
    - Detail/Focus: {detail or 'N/A'}
    - Type/Genre: {ad_type}

    Provide a concise analysis of how this ad might affect viewer emotions and cognition.
    Focus on potential emotional triggers, cognitive demands, and engagement patterns.
    Keep your analysis under 100 words.
    """
    try:
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error calling Gemini API: {e}")
        return f"Simulated analysis: Ad='{description or 'N/A'}' ({ad_type}), Focus='{detail or 'N/A'}'."
def interpret_metrics_with_gemini(model, metrics_dict, audio_metrics_dict=None, ad_context=None, timestamp=None):
    """Uses Gemini to interpret facial and audio metrics -> detailed user state."""
    if not model:
        return simple_user_state_analysis(metrics_dict, audio_metrics_dict), "Gemini model not available. Using rule-based analysis."
    if not metrics_dict:
        return "No response", "No metrics data available"
    metrics_text = "\n".join([f"- {k}: {v:.3f}" for k, v in metrics_dict.items()])
    audio_text = ""
    if audio_metrics_dict:
        audio_text = "\n".join([f"- {k}: {v}" for k, v in audio_metrics_dict.items()])
    ad_text = ""
    if ad_context:
        ad_text = f"""
    Ad Context:
    - Description: {ad_context.get('ad_description', 'N/A')}
    - Detail/Focus: {ad_context.get('ad_detail', 'N/A')}
    - Type/Genre: {ad_context.get('ad_type', 'N/A')}
    """
    timestamp_text = f"Timestamp: {timestamp:.2f} seconds" if timestamp is not None else ""
    prompt = f"""
    Analyze the following viewer metrics and provide a detailed assessment of their current state:
    {timestamp_text}

    Facial Expression Metrics:
    {metrics_text}

    {'Audio Expression Metrics:' if audio_text else ''}
    {audio_text}
    {ad_text}

    First, provide a short 1-5 word state label that summarizes the viewer's current emotional and cognitive state.
    Then, provide a more detailed 2-3 sentence analysis explaining what these metrics suggest about the viewer's:
    - Emotional state
    - Cognitive engagement
    - Likely response to the content
    - Any notable patterns or anomalies

    Format your response as:
    USER STATE: [state label]
    DETAILED ANALYSIS: [your analysis]
    """
    try:
        response = model.generate_content(prompt)
        text = response.text.strip()
        # Parse the response
        state_parts = text.split("USER STATE:", 1)
        if len(state_parts) > 1:
            state_text = state_parts[1].split("DETAILED ANALYSIS:", 1)
            if len(state_text) > 1:
                simple_state = state_text[0].strip()
                detailed_analysis = state_text[1].strip()
                return simple_state, detailed_analysis
        # Fallback if parsing fails
        simple_state = text.split('\n')[0].strip()
        detailed_analysis = ' '.join(text.split('\n')[1:]).strip()
        return simple_state, detailed_analysis
    except Exception as e:
        print(f"Error interpreting metrics with Gemini: {e}")
        return simple_user_state_analysis(metrics_dict, audio_metrics_dict), "Error generating detailed analysis"
def simple_user_state_analysis(metrics_dict, audio_metrics_dict=None):
    """Simple rule-based user state analysis as fallback."""
    if not metrics_dict:
        return "No metrics"
    valence = metrics_dict.get('valence', 0.5)
    arousal = metrics_dict.get('arousal', 0.5)
    cog_load = metrics_dict.get('cognitive_load', 0.5)
    stress = metrics_dict.get('stress_index', 0.5)
    engagement = metrics_dict.get('engagement_level', 0.5)
    # Include audio metrics when available
    audio_emotion = None
    audio_valence = 0.5
    if audio_metrics_dict:
        audio_emotion = audio_metrics_dict.get('audio_emotion')
        audio_valence = audio_metrics_dict.get('audio_valence', 0.5)
        # Blend facial and audio valence
        valence = (valence * 0.7) + (audio_valence * 0.3)
    # Simple rule-based analysis
    state = "Neutral"
    if valence > 0.65 and arousal > 0.55 and engagement > 0.6:
        state = "Positive, Engaged"
    elif valence < 0.4 and stress > 0.6:
        state = "Stressed, Negative"
    elif cog_load > 0.7 and engagement < 0.4:
        state = "Confused, Disengaged"
    elif arousal < 0.4 and engagement < 0.5:
        state = "Calm, Passive"
    # Override with the audio emotion when it is detected with high confidence
    if audio_emotion in ["happy", "excited", "angry", "sad", "fearful"] and audio_metrics_dict.get('audio_confidence', 0) > 0.7:
        state = audio_emotion.capitalize()
    return state
# --- Audio Analysis Functions ---
def extract_audio_from_video(video_path, output_audio_path=None):
    """Extract audio from video file"""
    if output_audio_path is None:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp)
    try:
        video = VideoFileClip(video_path)
        if video.audio is None:
            # Video has no audio track; skip audio analysis gracefully
            print("No audio track found in video")
            return None
        video.audio.write_audiofile(output_audio_path, fps=16000, nbytes=2, codec='pcm_s16le')
        return output_audio_path
    except Exception as e:
        print(f"Error extracting audio: {e}")
        return None
def analyze_audio_segment(audio_path, start_time, duration=1.0):
    """Analyze a segment of audio for emotion"""
    classifier, feature_extractor = initialize_audio_model()
    try:
        # Load audio segment
        y, sr = librosa.load(audio_path, sr=16000, offset=start_time, duration=duration)
        if len(y) < 100:  # Too short to analyze
            return None
        # Extract features
        inputs = feature_extractor(y, sampling_rate=sr, return_tensors="pt")
        # Get predictions
        with torch.no_grad():
            outputs = classifier(**inputs)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)
        # Get the predicted class and its probability
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class_idx].item()
        # Read the label from the model config rather than hardcoding a label list,
        # so the index cannot drift out of sync with the checkpoint's own classes
        predicted_emotion = classifier.config.id2label[predicted_class_idx].lower()
        # Map emotions to approximate valence/arousal/intensity values
        emotion_mappings = {
            "angry": {"valence": 0.2, "arousal": 0.9, "intensity": 0.8},
            "fearful": {"valence": 0.3, "arousal": 0.8, "intensity": 0.7},
            "happy": {"valence": 0.9, "arousal": 0.7, "intensity": 0.6},
            "neutral": {"valence": 0.5, "arousal": 0.5, "intensity": 0.3},
            "sad": {"valence": 0.2, "arousal": 0.3, "intensity": 0.5},
            "surprised": {"valence": 0.6, "arousal": 0.8, "intensity": 0.7}
        }
        mapping = emotion_mappings.get(predicted_emotion, {"valence": 0.5, "arousal": 0.5, "intensity": 0.5})
        # Return audio metrics
        return {
            "audio_valence": mapping["valence"],
            "audio_arousal": mapping["arousal"],
            "audio_intensity": mapping["intensity"],
            "audio_emotion": predicted_emotion,
            "audio_confidence": confidence
        }
    except Exception as e:
        print(f"Error analyzing audio segment: {e}")
        return None
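
# Illustrative call (hypothetical file name and values) analyzing the first second
# of an extracted track:
#   segment = analyze_audio_segment("audio_20240101_120000.wav", start_time=0.0, duration=1.0)
#   # -> {"audio_valence": 0.9, "audio_arousal": 0.7, "audio_intensity": 0.6,
#   #     "audio_emotion": "happy", "audio_confidence": 0.82}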
# --- Analysis Functions ---
def extract_face_landmarks(image, face_mesh_instance):
    if image is None or face_mesh_instance is None:
        return None
    try:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_rgb.flags.writeable = False
        results = face_mesh_instance.process(image_rgb)
        image_rgb.flags.writeable = True
        if results.multi_face_landmarks:
            return results.multi_face_landmarks[0]
    except Exception as e:
        print(f"Error in landmark extraction: {e}")
    return None
def calculate_ear(landmarks):
    if not landmarks:
        return 0.0
    try:
        LEFT_EYE = [33, 160, 158, 133, 153, 144]
        RIGHT_EYE = [362, 385, 387, 263, 373, 380]

        def get_coords(idx_list):
            return np.array([(landmarks.landmark[i].x, landmarks.landmark[i].y) for i in idx_list])

        left_pts = get_coords(LEFT_EYE)
        right_pts = get_coords(RIGHT_EYE)

        def ear_aspect(pts):
            v1 = np.linalg.norm(pts[1] - pts[5])
            v2 = np.linalg.norm(pts[2] - pts[4])
            h = np.linalg.norm(pts[0] - pts[3])
            return (v1 + v2) / (2.0 * h) if h > 1e-6 else 0.0

        return (ear_aspect(left_pts) + ear_aspect(right_pts)) / 2.0
    except (IndexError, AttributeError) as e:
        print(f"Error calculating EAR: {e}")
        return 0.0
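
# For reference, the eye aspect ratio above follows the usual formulation
#   EAR = (||p2 - p6|| + ||p3 - p5||) / (2 * ||p1 - p4||)
# where p1..p6 are the six eye contour points (corners plus upper/lower lid pairs).
# EAR stays roughly constant while the eye is open and drops toward 0 during a blink;
# thresholds around 0.2 are commonly cited for blink detection, though this code uses
# the raw value as a proxy for eye openness rather than thresholding it.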
def calculate_mar(landmarks):
    if not landmarks:
        return 0.0
    try:
        MOUTH = [61, 291, 39, 181, 0, 17, 269, 405]
        pts = np.array([(landmarks.landmark[i].x, landmarks.landmark[i].y) for i in MOUTH])
        h = np.mean([np.linalg.norm(pts[1] - pts[7]), np.linalg.norm(pts[2] - pts[6]), np.linalg.norm(pts[3] - pts[5])])
        w = np.linalg.norm(pts[0] - pts[4])
        return h / w if w > 1e-6 else 0.0
    except (IndexError, AttributeError) as e:
        print(f"Error calculating MAR: {e}")
        return 0.0
def calculate_eyebrow_position(landmarks):
    if not landmarks:
        return 0.0
    try:
        L_BROW = 107
        R_BROW = 336
        L_EYE_C = 159
        R_EYE_C = 386
        l_brow_y = landmarks.landmark[L_BROW].y
        r_brow_y = landmarks.landmark[R_BROW].y
        l_eye_y = landmarks.landmark[L_EYE_C].y
        r_eye_y = landmarks.landmark[R_EYE_C].y
        l_dist = l_eye_y - l_brow_y
        r_dist = r_eye_y - r_brow_y
        avg_dist = (l_dist + r_dist) / 2.0
        norm = (avg_dist - 0.02) / 0.06
        return max(0.0, min(1.0, norm))
    except (IndexError, AttributeError) as e:
        print(f"Error calculating Eyebrow Pos: {e}")
        return 0.0
def estimate_head_pose(landmarks):
    if not landmarks:
        return 0.0, 0.0
    try:
        NOSE = 4
        L_EYE_C = 159
        R_EYE_C = 386
        nose_pt = np.array([landmarks.landmark[NOSE].x, landmarks.landmark[NOSE].y])
        l_eye_pt = np.array([landmarks.landmark[L_EYE_C].x, landmarks.landmark[L_EYE_C].y])
        r_eye_pt = np.array([landmarks.landmark[R_EYE_C].x, landmarks.landmark[R_EYE_C].y])
        eye_mid_y = (l_eye_pt[1] + r_eye_pt[1]) / 2.0
        eye_mid_x = (l_eye_pt[0] + r_eye_pt[0]) / 2.0
        v_tilt = nose_pt[1] - eye_mid_y
        h_tilt = nose_pt[0] - eye_mid_x
        v_tilt_norm = max(-1.0, min(1.0, v_tilt * 5.0))
        h_tilt_norm = max(-1.0, min(1.0, h_tilt * 10.0))
        return v_tilt_norm, h_tilt_norm
    except (IndexError, AttributeError) as e:
        print(f"Error estimating Head Pose: {e}")
        return 0.0, 0.0
def calculate_metrics(landmarks, ad_context=None):
    if ad_context is None:
        ad_context = {}
    if not landmarks:
        return {m: 0.5 for m in metrics}  # Return defaults if no landmarks
    # Calculate base features
    ear = calculate_ear(landmarks)
    mar = calculate_mar(landmarks)
    eb_pos = calculate_eyebrow_position(landmarks)
    v_tilt, h_tilt = estimate_head_pose(landmarks)
    # Illustrative context adjustments
    ad_type = ad_context.get('ad_type', 'Unk')
    gem_txt = str(ad_context.get('gemini_ad_analysis', '')).lower()
    val_mar_w = 2.5 if ad_type == 'Funny' or 'humor' in gem_txt else 2.0
    val_eb_w = 0.8 if ad_type == 'Serious' or 'sad' in gem_txt else 1.0
    arsl_base = 0.05 if ad_type == 'Action' or 'exciting' in gem_txt else 0.0
    # Calculate final metrics using base features and context adjustments
    cl = max(0, min(1, 1.0 - ear * 2.5))
    val = max(0, min(1, mar * val_mar_w * (val_eb_w * (1.0 - eb_pos))))
    arsl = max(0, min(1, arsl_base + (mar + (1.0 - ear) + eb_pos) / 3.0))
    dom = max(0, min(1, 0.5 + v_tilt))
    neur = max(0, min(1, (cl * 0.6) + ((1.0 - val) * 0.4)))
    em_stab = 1.0 - neur
    extr = max(0, min(1, (arsl * 0.5) + (val * 0.5)))
    open_v = max(0, min(1, 0.5 + ((mar - 0.5) * 0.5)))  # renamed to avoid shadowing the builtin open()
    agree = max(0, min(1, (val * 0.7) + ((1.0 - arsl) * 0.3)))
    consc = max(0, min(1, (1.0 - abs(arsl - 0.5)) * 0.7 + (em_stab * 0.3)))
    stress = max(0, min(1, (cl * 0.5) + (eb_pos * 0.3) + ((1.0 - val) * 0.2)))
    engag = max(0, min(1, (arsl * 0.7) + ((1.0 - abs(h_tilt)) * 0.3)))
    # Return dictionary of metrics
    return {
        'valence': val, 'arousal': arsl, 'dominance': dom, 'cognitive_load': cl,
        'emotional_stability': em_stab, 'openness': open_v, 'agreeableness': agree,
        'neuroticism': neur, 'conscientiousness': consc, 'extraversion': extr,
        'stress_index': stress, 'engagement_level': engag
    }
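
# Illustrative call (values are hypothetical): for a frame with detected landmarks,
#   calculate_metrics(landmarks, {"ad_type": "Funny", "gemini_ad_analysis": "light humor"})
# returns one entry per name in `metrics`, each clamped to [0, 1], e.g.
#   {"valence": 0.62, "arousal": 0.48, ..., "engagement_level": 0.55}
# These heuristics are rough proxies derived from EAR/MAR/eyebrow/head-pose features,
# not validated psychometric measurements.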
def update_metrics_visualization(metrics_values, audio_metrics=None, title=None):
    if not metrics_values:
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(0.5, 0.5, "Waiting...", ha='center', va='center')
        ax.axis('off')
        fig.patch.set_facecolor('#FFFFFF')
        ax.set_facecolor('#FFFFFF')
        return fig
    # Combine face and audio metrics for visualization
    all_metrics = {}
    for k, v in metrics_values.items():
        if k not in ('timestamp', 'frame_number', 'user_state', 'detailed_user_analysis'):
            all_metrics[k] = v
    if audio_metrics:
        for k, v in audio_metrics.items():
            if isinstance(v, (int, float)):
                all_metrics[k] = v
    num_metrics = len(all_metrics)
    nrows = (num_metrics + 2) // 3
    fig, axs = plt.subplots(nrows, 3, figsize=(10, nrows * 2.5), facecolor='#FFFFFF')
    axs = axs.flatten()
    if title:
        fig.suptitle(title, fontsize=12)
    colors = [(0.1, 0.1, 0.9), (0.9, 0.9, 0.1), (0.9, 0.1, 0.1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=100)
    norm = plt.Normalize(0, 1)
    metric_idx = 0
    for key, value in all_metrics.items():
        if not isinstance(value, (int, float)):
            value = 0.5
        value = max(0.0, min(1.0, value))
        ax = axs[metric_idx]
        ax.set_title(key.replace('_', ' ').title(), fontsize=10)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 0.5)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_facecolor('#FFFFFF')
        r = 0.4
        theta = np.linspace(np.pi, 0, 100)
        x_bg = 0.5 + r * np.cos(theta)
        y_bg = 0.1 + r * np.sin(theta)
        ax.plot(x_bg, y_bg, 'k-', linewidth=3, alpha=0.2)
        value_angle = np.pi * (1 - value)
        num_points = max(2, int(100 * value))
        value_theta = np.linspace(np.pi, value_angle, num_points)
        x_val = 0.5 + r * np.cos(value_theta)
        y_val = 0.1 + r * np.sin(value_theta)
        if len(x_val) > 1:
            points = np.array([x_val, y_val]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            segment_values = np.linspace(0, value, len(segments))
            lc = LineCollection(segments, cmap=cmap, norm=norm)
            lc.set_array(segment_values)
            lc.set_linewidth(5)
            ax.add_collection(lc)
        ax.text(0.5, 0.15, f"{value:.2f}", ha='center', va='center', fontsize=11,
                fontweight='bold', bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.2'))
        metric_idx += 1
    for i in range(metric_idx, len(axs)):
        axs[i].axis('off')
    plt.tight_layout(pad=0.5)
    return fig
def create_user_state_display(state_text, detailed_analysis=None):
    """Create a visual display of the user state"""
    fig, ax = plt.subplots(figsize=(10, 2.5))
    ax.axis('off')
    # Display state
    ax.text(0.5, 0.8, f"USER STATE: {state_text}",
            ha='center', va='center', fontsize=14, fontweight='bold',
            bbox=dict(facecolor='#e6f2ff', alpha=0.7, boxstyle='round,pad=0.5'))
    # Display detailed analysis if available
    if detailed_analysis:
        ax.text(0.5, 0.3, detailed_analysis,
                ha='center', va='center', fontsize=10,
                bbox=dict(facecolor='#f2f2f2', alpha=0.7, boxstyle='round,pad=0.5'))
    plt.tight_layout()
    return fig
def annotate_frame(frame, landmarks):
    """Add facial landmark annotations to a frame"""
    if frame is None:
        return None
    annotated = frame.copy()
    if landmarks:
        try:
            mp_drawing.draw_landmarks(
                image=annotated,
                landmark_list=landmarks,
                connections=mp_face_mesh.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_tesselation_style()
            )
            mp_drawing.draw_landmarks(
                image=annotated,
                landmark_list=landmarks,
                connections=mp_face_mesh.FACEMESH_CONTOURS,
                landmark_drawing_spec=None,
                connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_contours_style()
            )
        except Exception as e:
            print(f"Error drawing landmarks: {e}")
    return annotated
# --- Background Processing Functions ---
def process_frames_in_background(session):
    """Background thread for processing frames and updating metrics"""
    while True:
        try:
            # Get task from queue
            task = processing_queue.get(timeout=1.0)
            if task.get('command') == 'stop':
                break
            frame = task.get('frame')
            if frame is None:
                continue
            # Process the frame
            result = process_webcam_frame(
                frame,
                task.get('ad_context', {}),
                task.get('metrics_data', initial_metrics_df.copy()),
                task.get('frame_count', 0),
                task.get('start_time', time.time()),
                task.get('audio_path'),
                task.get('gemini_model')
            )
            # Put result in results queue
            results_queue.put({
                'annotated_frame': result[0],
                'metrics': result[1],
                'audio_metrics': result[2],
                'metrics_df': result[3],
                'state_fig': result[4],
                'metrics_fig': result[5]
            })
            # Mark task as done
            processing_queue.task_done()
        except queue.Empty:
            continue
        except Exception as e:
            print(f"Error in background processing: {e}")
            continue
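
# The {'command': 'stop'} task placed on processing_queue by end_webcam_session() is what
# breaks this loop; the thread is created as a daemon in start_webcam_session(), so it also
# will not keep the process alive if the app exits without an explicit stop.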
# --- Video File Processing with Progress Updates ---
def process_video_file(
    video_file: Union[str, np.ndarray],
    ad_description: str = "",
    ad_detail: str = "",
    ad_type: str = "Video",
    sampling_rate: int = 5,  # Process every Nth frame
    save_processed_video: bool = True,
    progress=gr.Progress()
) -> Tuple[str, str, str, pd.DataFrame]:
    """
    Process a video file and analyze facial expressions frame by frame.

    Args:
        video_file: Path to video file or video array
        ad_description: Description of the ad being watched
        ad_detail: Detail focus of the ad
        ad_type: Type of ad (Video, Image, Audio, Text, Funny, etc.)
        sampling_rate: Process every Nth frame
        save_processed_video: Whether to save the processed video with annotations
        progress: Gradio progress bar

    Returns:
        Tuple of (csv_path, audio_path, processed_video_path, metrics_dataframe)
    """
    # Initialize Gemini model
    gemini_model = configure_gemini()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
    audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp)
    video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if save_processed_video else None
    # Setup ad context
    gemini_result = call_gemini_api_for_ad(gemini_model, ad_description, ad_detail, ad_type)
    ad_context = {
        "ad_description": ad_description,
        "ad_detail": ad_detail,
        "ad_type": ad_type,
        "gemini_ad_analysis": gemini_result
    }
    progress(0, desc="Initializing video processing")
    # Initialize capture
    if isinstance(video_file, str):
        cap = cv2.VideoCapture(video_file)
    else:
        # Create a temporary file for the video array
        temp_dir = tempfile.mkdtemp()
        temp_path = os.path.join(temp_dir, "temp_video.mp4")
        # Convert video array to file
        if isinstance(video_file, np.ndarray):
            # Assuming it's a series of frames
            h, w = video_file[0].shape[:2] if len(video_file) > 0 else (480, 640)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            temp_writer = cv2.VideoWriter(temp_path, fourcc, 30, (w, h))
            for frame in video_file:
                temp_writer.write(frame)
            temp_writer.release()
            video_file = temp_path
        cap = cv2.VideoCapture(temp_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None, None, None, None
    # Extract audio for analysis (returns None if the video has no audio track)
    audio_extracted = extract_audio_from_video(video_file, audio_path)
    # Get video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = VIDEO_FPS  # Fall back to the estimated FPS if the container reports none
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 1  # Guard against containers that report 0 frames
    # Initialize video writer if saving processed video
    if save_processed_video:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_path, fourcc, fps, (frame_width, frame_height))
    # Process video frames
    metrics_data = []
    frame_count = 0
    # Create a thread pool for audio processing
    with ThreadPoolExecutor(max_workers=2) as executor:
        # Futures keyed by frame timestamp for audio analysis results
        audio_futures = {}
        progress(0.1, desc="Starting frame analysis")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Only process every Nth frame (according to sampling_rate)
            process_this_frame = frame_count % sampling_rate == 0
            frame_timestamp = frame_count / fps
            if process_this_frame:
                progress(min(0.1 + 0.8 * (frame_count / total_frames), 0.9),
                         desc=f"Processing frame {frame_count}/{total_frames}")
                # Extract facial landmarks
                landmarks = extract_face_landmarks(frame, face_mesh)
                # Submit audio analysis task if audio was extracted
                if audio_extracted and frame_timestamp not in audio_futures:
                    audio_futures[frame_timestamp] = executor.submit(
                        analyze_audio_segment, audio_path, frame_timestamp, 1.0
                    )
                # Get audio analysis results if available
                audio_metrics_result = None
                if frame_timestamp in audio_futures and audio_futures[frame_timestamp].done():
                    audio_metrics_result = audio_futures[frame_timestamp].result()
                # Calculate metrics if landmarks detected
                if landmarks:
                    calculated_metrics = calculate_metrics(landmarks, ad_context)
                    user_state, detailed_analysis = interpret_metrics_with_gemini(
                        gemini_model, calculated_metrics, audio_metrics_result, ad_context, frame_timestamp
                    )
                    # Create a row for the dataframe
                    row = {
                        'timestamp': frame_timestamp,
                        'frame_number': frame_count,
                        **calculated_metrics
                    }
                    # Add audio metrics if available, otherwise fill neutral defaults
                    if audio_metrics_result:
                        row.update(audio_metrics_result)
                    else:
                        row.update({
                            "audio_valence": 0.5, "audio_arousal": 0.5, "audio_intensity": 0.5,
                            "audio_emotion": "unknown", "audio_confidence": 0.0
                        })
                    # Add context and state
                    row.update(ad_context)
                    row['user_state'] = user_state
                    row['detailed_user_analysis'] = detailed_analysis
                    metrics_data.append(row)
                    # Annotate the frame with facial landmarks
                    if save_processed_video:
                        annotated_frame = annotate_frame(frame, landmarks)
                        # Add user state text to frame
                        cv2.putText(
                            annotated_frame,
                            f"State: {user_state}",
                            (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7,
                            (0, 255, 0),
                            2
                        )
                        # Add audio emotion if available
                        if audio_metrics_result and 'audio_emotion' in audio_metrics_result:
                            cv2.putText(
                                annotated_frame,
                                f"Audio: {audio_metrics_result['audio_emotion']}",
                                (10, 60),
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.7,
                                (255, 0, 0),
                                2
                            )
                        out.write(annotated_frame)
                elif save_processed_video:
                    # If no landmarks detected, still write the original frame to the video
                    out.write(frame)
            elif save_processed_video:
                # For frames not being analyzed, still include them in the output video
                out.write(frame)
            frame_count += 1
        # Wait for all audio analysis to complete
        for future in audio_futures.values():
            if not future.done():
                future.result()  # This will wait for completion
    progress(0.95, desc="Finalizing results")
    # Release resources
    cap.release()
    if save_processed_video:
        out.release()
    # Create DataFrame and save to CSV
    metrics_df = pd.DataFrame(metrics_data)
    if not metrics_df.empty:
        metrics_df.to_csv(csv_path, index=False)
        progress(1.0, desc="Processing complete")
    else:
        progress(1.0, desc="No facial data detected")
    # Return results (audio path is None when no audio could be extracted)
    return csv_path, audio_extracted, video_path, metrics_df
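
# Example invocation (hypothetical file name), bypassing the Gradio UI:
#   csv_path, audio_path, video_path, df = process_video_file(
#       "sample_ad.mp4", ad_description="New soda spot", ad_detail="brand recall",
#       ad_type="Funny", sampling_rate=10, save_processed_video=False)
# Every sampled frame triggers a Gemini call in interpret_metrics_with_gemini(), so larger
# sampling_rate values keep API usage (and processing time) down.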
# --- Updated Webcam Processing Function ---
def process_webcam_frame(
    frame: np.ndarray,
    ad_context: Dict[str, Any],
    metrics_data: pd.DataFrame,
    frame_count: int,
    start_time: float,
    audio_path: str = None,
    gemini_model=None
) -> Tuple[np.ndarray, Dict[str, float], Dict[str, Any], pd.DataFrame, object, object]:
    """
    Process a single webcam frame with audio integration.

    Args:
        frame: Input frame from webcam
        ad_context: Ad context dictionary
        metrics_data: DataFrame to accumulate metrics
        frame_count: Current frame count
        start_time: Start time of the session
        audio_path: Path to extracted audio file (if available)
        gemini_model: Configured Gemini model instance

    Returns:
        Tuple of (annotated_frame, metrics_dict, audio_metrics, updated_metrics_df, state_fig, metrics_fig)
    """
    if frame is None:
        return None, None, None, metrics_data, None, None
    # Extract facial landmarks
    landmarks = extract_face_landmarks(frame, face_mesh)
    # Get current timestamp
    current_time = time.time()
    elapsed_time = current_time - start_time
    # Analyze audio segment if available
    audio_metrics_result = None
    if audio_path and os.path.exists(audio_path):
        audio_metrics_result = analyze_audio_segment(audio_path, elapsed_time, 1.0)
    # Calculate metrics if landmarks detected
    if landmarks:
        calculated_metrics = calculate_metrics(landmarks, ad_context)
        user_state, detailed_analysis = interpret_metrics_with_gemini(
            gemini_model, calculated_metrics, audio_metrics_result, ad_context, elapsed_time
        )
        # Create a row for the dataframe
        row = {
            'timestamp': elapsed_time,
            'frame_number': frame_count,
            **calculated_metrics
        }
        # Add audio metrics if available, otherwise fill neutral defaults
        if audio_metrics_result:
            row.update(audio_metrics_result)
        else:
            row.update({
                "audio_valence": 0.5, "audio_arousal": 0.5, "audio_intensity": 0.5,
                "audio_emotion": "unknown", "audio_confidence": 0.0
            })
        # Add context and state
        row.update(ad_context)
        row['user_state'] = user_state
        row['detailed_user_analysis'] = detailed_analysis
        # Add row to DataFrame
        new_row_df = pd.DataFrame([row], columns=all_columns)
        metrics_data = pd.concat([metrics_data, new_row_df], ignore_index=True)
        # Create visualizations
        metrics_plot = update_metrics_visualization(
            calculated_metrics,
            audio_metrics_result,
            title=f"Frame {frame_count} Metrics"
        )
        state_plot = create_user_state_display(user_state, detailed_analysis)
        # Annotate the frame
        annotated_frame = annotate_frame(frame, landmarks)
        # Add user state text to frame
        cv2.putText(
            annotated_frame,
            f"State: {user_state}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 255, 0),
            2
        )
        # Add audio emotion if available
        if audio_metrics_result and 'audio_emotion' in audio_metrics_result:
            cv2.putText(
                annotated_frame,
                f"Audio: {audio_metrics_result['audio_emotion']}",
                (10, 60),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (255, 0, 0),
                2
            )
        return annotated_frame, calculated_metrics, audio_metrics_result, metrics_data, state_plot, metrics_plot
    else:
        # No face detected
        return frame, None, None, metrics_data, None, None
# --- Updated Webcam Session Functions ---
def start_webcam_session(
    ad_description: str = "",
    ad_detail: str = "",
    ad_type: str = "Video",
    save_interval: int = 100,  # Save CSV every N frames
    record_audio: bool = False
) -> Dict[str, Any]:
    """
    Initialize a webcam session for facial analysis with audio recording.

    Args:
        ad_description: Description of the ad being watched
        ad_detail: Detail focus of the ad
        ad_type: Type of ad
        save_interval: How often to save data to CSV
        record_audio: Whether to record audio during session

    Returns:
        Session context dictionary
    """
    # Generate timestamp for file naming
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
    audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp) if record_audio else None
    # Initialize Gemini model
    gemini_model = configure_gemini()
    # Setup ad context
    gemini_result = call_gemini_api_for_ad(gemini_model, ad_description, ad_detail, ad_type)
    ad_context = {
        "ad_description": ad_description,
        "ad_detail": ad_detail,
        "ad_type": ad_type,
        "gemini_ad_analysis": gemini_result
    }
    # Initialize session context
    session = {
        "start_time": time.time(),
        "frame_count": 0,
        "metrics_data": initial_metrics_df.copy(),
        "ad_context": ad_context,
        "csv_path": csv_path,
        "audio_path": audio_path,
        "save_interval": save_interval,
        "last_saved": 0,
        "gemini_model": gemini_model,
        "processing_thread": None
    }
    # Start background processing thread
    processor = threading.Thread(target=process_frames_in_background, args=(session,))
    processor.daemon = True
    processor.start()
    session["processing_thread"] = processor
    return session
def update_webcam_session(
    session: Dict[str, Any],
    frame: np.ndarray
) -> Tuple[np.ndarray, object, object, Dict[str, Any]]:
    """
    Update webcam session with a new frame.

    Args:
        session: Session context dictionary
        frame: New frame from webcam

    Returns:
        Tuple of (annotated_frame, state_plot, metrics_plot, updated_session)
    """
    if session is None:
        return frame, None, None, session
    # Add task to processing queue
    processing_queue.put({
        'command': 'process',
        'frame': frame.copy() if frame is not None else None,
        'ad_context': session["ad_context"],
        'metrics_data': session["metrics_data"],
        'frame_count': session["frame_count"],
        'start_time': session["start_time"],
        'audio_path': session["audio_path"],
        'gemini_model': session["gemini_model"]
    })
    # Update frame count
    session["frame_count"] += 1
    # Get result if available
    try:
        result = results_queue.get_nowait()
        annotated_frame = result.get('annotated_frame', frame)
        state_fig = result.get('state_fig')
        metrics_fig = result.get('metrics_fig')
        session["metrics_data"] = result.get('metrics_df', session["metrics_data"])
        results_queue.task_done()
    except queue.Empty:
        # No result yet, return original frame
        annotated_frame = frame
        state_fig = None
        metrics_fig = None
    # Save CSV periodically
    if session["frame_count"] - session["last_saved"] >= session["save_interval"]:
        if not session["metrics_data"].empty:
            session["metrics_data"].to_csv(session["csv_path"], index=False)
            session["last_saved"] = session["frame_count"]
    return annotated_frame, state_fig, metrics_fig, session
def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
    """
    End a webcam session and save final results.

    Args:
        session: Session context dictionary

    Returns:
        Tuple of (csv_path, audio_path)
    """
    if session is None:
        return None, None
    # Stop background processing thread
    if session["processing_thread"] and session["processing_thread"].is_alive():
        processing_queue.put({"command": "stop"})
        session["processing_thread"].join(timeout=2.0)
    # Save final metrics to CSV
    if not session["metrics_data"].empty:
        session["metrics_data"].to_csv(session["csv_path"], index=False)
    print(f"Session ended. Data saved to {session['csv_path']}")
    return session["csv_path"], session["audio_path"]
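
# Typical session lifecycle when calling these functions directly (outside the UI):
#   session = start_webcam_session(ad_description="Demo ad", ad_type="Funny")
#   for frame in frames:  # frames: BGR numpy arrays from any capture source
#       annotated, state_fig, metrics_fig, session = update_webcam_session(session, frame)
#   csv_path, audio_path = end_webcam_session(session)
# Note: record_audio=True only reserves an output path; no microphone capture is implemented
# here, so audio-based metrics stay at their defaults in webcam mode.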
# --- Create Enhanced Gradio Interface ---
def create_api_interface():
    with gr.Blocks(title="Enhanced Facial Analysis APIs") as iface:
        gr.Markdown("# Enhanced Facial Analysis APIs\nAnalyze facial expressions and audio in videos or webcam feed")
        with gr.Tab("Video File API"):
            with gr.Row():
                with gr.Column(scale=1):
                    video_input = gr.Video(label="Upload Video")
                    vid_ad_desc = gr.Textbox(label="Ad Description")
                    vid_ad_detail = gr.Textbox(label="Ad Detail Focus")
                    vid_ad_type = gr.Radio(
                        ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
                        label="Ad Type/Genre",
                        value="Video"
                    )
                    sampling_rate = gr.Slider(
                        minimum=1, maximum=30, step=1, value=5,
                        label="Sampling Rate (process every N frames)"
                    )
                    save_video = gr.Checkbox(label="Save Processed Video", value=True)
                    process_btn = gr.Button("Process Video")
                with gr.Column(scale=2):
                    with gr.Row():
                        output_text = gr.Textbox(label="Processing Status")
                    with gr.Row():
                        output_video = gr.Video(label="Processed Video")
                    with gr.Row():
                        output_plot = gr.Plot(label="Metrics Visualization")
                        user_state_plot = gr.Plot(label="User State Analysis")
                    with gr.Row():
                        output_csv = gr.File(label="Download CSV Results")
                        output_audio = gr.Audio(label="Extracted Audio")

            # Define function to handle video processing with live updates
            def handle_video_processing(video, desc, detail, ad_type, rate, save_vid, progress=gr.Progress()):
                if video is None:
                    return "No video uploaded", None, None, None, None, None
                try:
                    progress(0.05, "Starting video processing...")
                    csv_path, audio_path, video_path, metrics_df = process_video_file(
                        video,
                        ad_description=desc,
                        ad_detail=detail,
                        ad_type=ad_type,
                        sampling_rate=rate,
                        save_processed_video=save_vid,
                        progress=progress
                    )
                    if metrics_df is None or metrics_df.empty:
                        return "No facial data detected in video", None, None, None, None, None
                    # Get a sample row for visualization
                    middle_idx = len(metrics_df) // 2
                    sample_row = metrics_df.iloc[middle_idx].to_dict()
                    # Generate visualizations
                    metrics_plot = update_metrics_visualization(
                        {k: v for k, v in sample_row.items() if k in metrics},
                        {k: v for k, v in sample_row.items() if k in audio_metrics},
                        title=f"Sample Frame Metrics (Frame {sample_row['frame_number']})"
                    )
                    state_plot = create_user_state_display(
                        sample_row.get('user_state', 'No state'),
                        sample_row.get('detailed_user_analysis', '')
                    )
                    processed_frames = metrics_df.shape[0]
                    total_duration = metrics_df['timestamp'].max() if not metrics_df.empty else 0
                    result_text = "✅ Processing complete!\n"
                    result_text += f"• Analyzed {processed_frames} frames over {total_duration:.2f} seconds\n"
                    result_text += f"• CSV saved to: {csv_path}\n"
                    if audio_path:
                        result_text += f"• Audio extracted to: {audio_path}\n"
                    if video_path:
                        result_text += f"• Processed video saved to: {video_path}\n"
                    return result_text, csv_path, video_path, audio_path, metrics_plot, state_plot
                except Exception as e:
                    return f"Error processing video: {str(e)}", None, None, None, None, None

            process_btn.click(
                handle_video_processing,
                inputs=[video_input, vid_ad_desc, vid_ad_detail, vid_ad_type, sampling_rate, save_video],
                outputs=[output_text, output_csv, output_video, output_audio, output_plot, user_state_plot]
            )
with gr.Tab("Webcam API"): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
webcam_input = gr.Image(sources="webcam", streaming=True, label="Webcam Input", type="numpy") | |
web_ad_desc = gr.Textbox(label="Ad Description") | |
web_ad_detail = gr.Textbox(label="Ad Detail Focus") | |
web_ad_type = gr.Radio( | |
["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"], | |
label="Ad Type/Genre", | |
value="Video" | |
) | |
record_audio = gr.Checkbox(label="Record Audio", value=True) | |
start_session_btn = gr.Button("Start Session") | |
end_session_btn = gr.Button("End Session") | |
with gr.Column(scale=2): | |
with gr.Row(): | |
processed_output = gr.Image(label="Processed Feed", type="numpy") | |
with gr.Row(): | |
metrics_plot = gr.Plot(label="Live Metrics") | |
state_plot = gr.Plot(label="User State Analysis") | |
with gr.Row(): | |
session_status = gr.Textbox(label="Session Status") | |
download_csv = gr.File(label="Download Session Data") | |
# Session state | |
session_data = gr.State(value=None) | |
# Define session handlers | |
def start_session(desc, detail, ad_type, record_audio): | |
try: | |
session = start_webcam_session( | |
ad_description=desc, | |
ad_detail=detail, | |
ad_type=ad_type, | |
record_audio=record_audio | |
) | |
status_text = "✅ Session started successfully!\n\n" | |
status_text += f"• Ad Context: {desc} ({ad_type})\n" | |
status_text += f"• Focus: {detail}\n" | |
status_text += f"• Audio Recording: {'Enabled' if record_audio else 'Disabled'}\n" | |
status_text += f"• Data will be saved to: {session['csv_path']}" | |
return session, status_text | |
except Exception as e: | |
return None, f"Error starting session: {str(e)}" | |
def process_frame(frame, session): | |
if session is None or frame is None: | |
return frame, None, None, session | |
try: | |
annotated_frame, state_fig, metrics_fig, updated_session = update_webcam_session(session, frame) | |
return annotated_frame, state_fig, metrics_fig, updated_session | |
except Exception as e: | |
print(f"Error processing frame: {e}") | |
return frame, None, None, session | |
def end_session(session): | |
if session is None: | |
return "No active session", None | |
try: | |
csv_path, audio_path = end_webcam_session(session) | |
status_text = "✅ Session ended successfully!\n\n" | |
status_text += f"• Data saved to: {csv_path}\n" | |
if audio_path: | |
status_text += f"• Audio saved to: {audio_path}" | |
return status_text, csv_path | |
except Exception as e: | |
return f"Error ending session: {str(e)}", None | |
start_session_btn.click( | |
start_session, | |
inputs=[web_ad_desc, web_ad_detail, web_ad_type, record_audio], | |
outputs=[session_data, session_status] | |
) | |
webcam_input.stream( | |
process_frame, | |
inputs=[webcam_input, session_data], | |
outputs=[processed_output, state_plot, metrics_plot, session_data] | |
) | |
end_session_btn.click( | |
end_session, | |
inputs=[session_data], | |
outputs=[session_status, download_csv] | |
) | |
return iface | |
# Entry point
if __name__ == "__main__":
    print("Starting Enhanced Facial Analysis API server...")
    # Pre-initialize models if needed
    # initialize_audio_model()
    iface = create_api_interface()
    iface.launch(debug=True)