import os
import tempfile

# Create a cross-platform, writable cache directory for all libraries.
# The cache environment variables should be set before transformers and
# deepface are imported so both libraries pick up the custom location.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "affectlink_cache")
DEEPFACE_CACHE_PATH = os.path.join(CACHE_DIR, ".deepface", "weights")
os.makedirs(DEEPFACE_CACHE_PATH, exist_ok=True)  # Proactively create the full path
os.environ['DEEPFACE_HOME'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR

import streamlit as st
import numpy as np
import torch
import whisper
from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor
from deepface import DeepFace
import logging
import soundfile as sf
import cv2
from moviepy.editor import VideoFileClip
import time
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# --- Page Configuration ---
st.set_page_config(page_title="AffectLink Demo", page_icon="🎭", layout="wide")
st.title("AffectLink: Post-Hoc Emotion Analysis")
st.write("Upload a short video clip (under 30 seconds) to see a multimodal emotion analysis.")

# --- Logger Configuration ---
logging.basicConfig(level=logging.INFO)

# --- Emotion Mappings ---
UNIFIED_EMOTIONS = ['angry', 'happy', 'sad', 'neutral']
TEXT_TO_UNIFIED = {'neutral': 'neutral', 'joy': 'happy', 'sadness': 'sad', 'anger': 'angry'}
SER_TO_UNIFIED = {'neu': 'neutral', 'hap': 'happy', 'sad': 'sad', 'ang': 'angry'}
FACIAL_TO_UNIFIED = {'neutral': 'neutral', 'happy': 'happy', 'sad': 'sad', 'angry': 'angry', 'fear': None, 'surprise': None, 'disgust': None}
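# Whisper and the HuBERT speech-emotion model both expect 16 kHz mono audio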
AUDIO_SAMPLE_RATE = 16000

# --- Model Loading (Staged) ---
def load_audio_models():
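    """Load the Whisper, text-emotion, and speech-emotion models used in Stage 1."""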
    with st.spinner("Loading audio analysis models..."):
        whisper_model = whisper.load_model("tiny.en", download_root=os.path.join(CACHE_DIR, "whisper"))
        text_classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=None)
        ser_model_name = "superb/hubert-large-superb-er"
        ser_feature_extractor = AutoFeatureExtractor.from_pretrained(ser_model_name)
        ser_model = AutoModelForAudioClassification.from_pretrained(ser_model_name)
    return whisper_model, text_classifier, ser_model, ser_feature_extractor

# Models will be loaded on demand

# --- Helper Functions for Analysis ---
def create_unified_vector(scores_dict, mapping_dict):
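    """Map raw classifier scores onto the unified emotion set and normalize them to sum to 1."""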
    vector = np.zeros(len(UNIFIED_EMOTIONS))
    total_score = 0
    # Use .items() to iterate over keys and values
    for label, score in scores_dict.items():
        unified_label = mapping_dict.get(label)
        if unified_label in UNIFIED_EMOTIONS:
            vector[UNIFIED_EMOTIONS.index(unified_label)] += score
            total_score += score
    if total_score > 0:
        vector /= total_score
    return vector

def get_consistency_level(cosine_sim):
    if np.isnan(cosine_sim): return "N/A"
    if cosine_sim >= 0.8: return "High"
    if cosine_sim >= 0.6: return "Medium"
    if cosine_sim >= 0.3: return "Low"
    return "Very Low"

# --- Helper Functions for Results Display ---
def process_timeline_to_df(timeline, mapping):
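    """Convert a {timestamp: {label: score}} timeline into a DataFrame with unified emotion columns."""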
    if not timeline: return pd.DataFrame(columns=UNIFIED_EMOTIONS)
    df = pd.DataFrame.from_dict(timeline, orient='index')
    df_unified = pd.DataFrame(0.0, index=df.index, columns=UNIFIED_EMOTIONS)
    for raw_col in df.columns:
        unified_col = mapping.get(raw_col)
        if unified_col:
            df_unified[unified_col] += df[raw_col]
    return df_unified

def get_dominant_emotion_from_df(df):
    if df.empty or df.sum().sum() == 0: return "N/A"
    return df.sum().idxmax().capitalize()

def get_avg_unified_scores(df):
    return df.mean().to_dict() if not df.empty else {}

def display_results():
    """Display the final analysis results using data from session state."""
    st.header("Analysis Results")

    # Get data from session state
    full_transcription = st.session_state.get('full_transcription', 'No speech detected.')
    ser_timeline = st.session_state.get('ser_timeline', {})
    ter_timeline = st.session_state.get('ter_timeline', {})
    fer_timeline = st.session_state.get('fer_timeline', {})
    duration = st.session_state.get('duration', 0)

    # Process timelines
    fer_df = process_timeline_to_df(fer_timeline, FACIAL_TO_UNIFIED)
    ser_df = process_timeline_to_df(ser_timeline, SER_TO_UNIFIED)
    ter_df = process_timeline_to_df(ter_timeline, TEXT_TO_UNIFIED)

    # Get dominant emotions
    dominant_fer = get_dominant_emotion_from_df(fer_df)
    dominant_ser = get_dominant_emotion_from_df(ser_df)
    dominant_text = get_dominant_emotion_from_df(ter_df)

    # Get average scores
    fer_avg_scores = get_avg_unified_scores(fer_df)
    ser_avg_scores = get_avg_unified_scores(ser_df)
    ter_avg_scores = get_avg_unified_scores(ter_df)

    # Calculate vectors and similarity
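    # The overall consistency score is the mean of the three pairwise cosine
    # similarities between the average facial, speech, and text emotion vectors.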
    fer_vector = create_unified_vector(fer_avg_scores, {e: e for e in UNIFIED_EMOTIONS})
    ser_vector = create_unified_vector(ser_avg_scores, {e: e for e in UNIFIED_EMOTIONS})
    text_vector = create_unified_vector(ter_avg_scores, {e: e for e in UNIFIED_EMOTIONS})
    similarities = [
        cosine_similarity([fer_vector], [text_vector])[0][0],
        cosine_similarity([fer_vector], [ser_vector])[0][0],
        cosine_similarity([ser_vector], [text_vector])[0][0],
    ]
    avg_similarity = np.nanmean([s for s in similarities if not np.isnan(s)])

    # Display transcription
    st.subheader("Transcription")
    st.markdown(f"> *{full_transcription}*")
    st.divider()

    # Display summary and timeline
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Multimodal Summary")
        st.metric("Dominant Facial Emotion", dominant_fer)
        st.metric("Dominant Text Emotion", dominant_text)
        st.metric("Dominant Speech Emotion", dominant_ser)
        st.metric("Emotion Consistency", get_consistency_level(avg_similarity), f"{avg_similarity:.2f} Avg. Cosine Similarity")
    with col2:
        st.subheader("Unified Emotion Timeline")
        if duration > 0:
            full_index = np.arange(0, duration, 0.5)
            combined_df = pd.DataFrame(index=full_index)

            # ECI Timeline Calculation
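            # At each 0.5 s step, build unified vectors for whichever modalities have
            # data and use their mean pairwise cosine similarity as the Emotion
            # Consistency Index (ECI).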
            eci_timeline = {}
            for t_stamp in full_index:
                vectors = []
                # Interpolate to get a value for any timestamp
                fer_scores = fer_df.reindex(fer_df.index.union([t_stamp])).interpolate(method='linear').loc[t_stamp]
                if not fer_scores.isnull().all():
                    vectors.append(create_unified_vector(fer_scores.to_dict(), {e: e for e in UNIFIED_EMOTIONS}))
                if int(t_stamp) in ser_df.index:
                    vectors.append(create_unified_vector(ser_df.loc[int(t_stamp)].to_dict(), {e: e for e in UNIFIED_EMOTIONS}))
                if int(t_stamp) in ter_df.index:
                    vectors.append(create_unified_vector(ter_df.loc[int(t_stamp)].to_dict(), {e: e for e in UNIFIED_EMOTIONS}))
                if len(vectors) >= 2:
                    sims = [cosine_similarity([v1], [v2])[0][0] for i, v1 in enumerate(vectors) for v2 in vectors[i + 1:]]
                    eci_timeline[t_stamp] = np.mean(sims)

            if not fer_df.empty:
                fer_df_resampled = fer_df.reindex(fer_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS:
                    combined_df[f'Facial_{e}'] = fer_df_resampled.get(e, 0.0)
            if not ser_df.empty:
                ser_df_resampled = ser_df.reindex(ser_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS:
                    combined_df[f'Speech_{e}'] = ser_df_resampled.get(e, 0.0)
            if not ter_df.empty:
                ter_df_resampled = ter_df.reindex(ter_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS:
                    combined_df[f'Text_{e}'] = ter_df_resampled.get(e, 0.0)
            if eci_timeline:
                eci_series = pd.Series(eci_timeline).reindex(full_index).interpolate(method='linear')
                combined_df['ECI'] = eci_series
            combined_df.fillna(0, inplace=True)

            if not combined_df.empty:
                fig, ax = plt.subplots(figsize=(10, 5))
                colors = {'happy': 'green', 'sad': 'blue', 'angry': 'red', 'neutral': 'gray'}
                styles = {'Facial': '-', 'Speech': '--', 'Text': ':'}
                for col in combined_df.columns:
                    if col == 'ECI': continue
                    modality, emotion = col.split('_')
                    if emotion in colors:
                        ax.plot(combined_df.index, combined_df[col], label=f'{modality} {emotion.capitalize()}', color=colors[emotion], linestyle=styles[modality], alpha=0.7)
                if 'ECI' in combined_df.columns:
                    ax.plot(combined_df.index, combined_df['ECI'], label='Emotion Consistency', color='black', linewidth=2.5, alpha=0.9)
                ax.set_title("Emotion Confidence Over Time (Normalized)")
                ax.set_xlabel("Time (seconds)")
                ax.set_ylabel("Confidence Score (0-1)")
                ax.set_ylim(0, 1)
                ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
                ax.grid(True, which='both', linestyle='--', linewidth=0.5)
                plt.tight_layout()
                st.pyplot(fig)
            else:
                st.write("No emotion data available to plot.")
        else:
            st.write("No timeline data available.")

# --- Two-Stage UI and Processing Logic ---
uploaded_file = st.file_uploader("Choose a video file...", type=["mp4", "mov", "avi", "mkv"])

# Initialize session state variables
if 'temp_video_path' not in st.session_state:
    st.session_state.temp_video_path = None
if 'uploaded_file_id' not in st.session_state:
    st.session_state.uploaded_file_id = None

# Clear previous results when a new file is uploaded
if uploaded_file is not None:
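    # Identify the upload: prefer Streamlit's file_id when available, otherwise
    # fall back to a hash of the file name and size.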
    file_id = uploaded_file.file_id if hasattr(uploaded_file, 'file_id') else str(hash(uploaded_file.name + str(uploaded_file.size)))
    if st.session_state.uploaded_file_id != file_id:
        # New file uploaded, clear previous results
        st.session_state.uploaded_file_id = file_id
        for key in ['stage1_complete', 'stage2_complete', 'full_transcription', 'ser_timeline', 'ter_timeline', 'fer_timeline', 'duration']:
            if key in st.session_state:
                del st.session_state[key]
        # Save the video file
        if st.session_state.temp_video_path and os.path.exists(st.session_state.temp_video_path):
            try:
                os.unlink(st.session_state.temp_video_path)
            except Exception:
                pass
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(uploaded_file.read())
            st.session_state.temp_video_path = tfile.name

if uploaded_file is not None and st.session_state.temp_video_path:
    st.video(st.session_state.temp_video_path)

    # Stage 1: Audio & Text Analysis
    if not st.session_state.get('stage1_complete', False):
        if st.button("🎵 Step 1: Analyze Audio & Text", type="primary"):
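            # Stage 1 transcribes the audio with Whisper, then scores speech emotion
            # (HuBERT) and text emotion (DistilRoBERTa) in one-second windows.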
            try:
                # Load audio models
                whisper_model, text_classifier, ser_model, ser_feature_extractor = load_audio_models()
                ser_timeline, ter_timeline = {}, {}
                full_transcription = "No speech detected."
                video_clip = VideoFileClip(st.session_state.temp_video_path)
                duration = video_clip.duration
                st.session_state.duration = duration
                with st.spinner("Analyzing audio and text..."):
                    if video_clip.audio:
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as taudio:
                            video_clip.audio.write_audiofile(taudio.name, fps=AUDIO_SAMPLE_RATE, logger=None)
                            temp_audio_path = taudio.name
                        # Transcription
                        whisper_result = whisper_model.transcribe(
                            temp_audio_path,
                            word_timestamps=True,
                            fp16=False,
                            condition_on_previous_text=False
                        )
                        full_transcription = whisper_result['text'].strip()
                        # Speech emotion recognition
                        audio_array, _ = sf.read(temp_audio_path, dtype='float32')
                        if audio_array.ndim == 2:
                            audio_array = audio_array.mean(axis=1)
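                        # Walk over the audio in non-overlapping one-second windows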
                        for i in range(int(duration)):
                            start_sample, end_sample = i * AUDIO_SAMPLE_RATE, (i + 1) * AUDIO_SAMPLE_RATE
                            chunk = audio_array[start_sample:end_sample]
                            if len(chunk) > 400:
                                inputs = ser_feature_extractor(chunk, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt", padding=True)
                                with torch.no_grad():
                                    logits = ser_model(**inputs).logits
                                scores = torch.nn.functional.softmax(logits, dim=1).squeeze()
                                ser_timeline[i] = {ser_model.config.id2label[k]: score.item() for k, score in enumerate(scores)}
                            # Text emotion recognition
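                            # Collect the words whose start timestamp falls inside this one-second window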
                            words_in_segment = [word['word'] for seg in whisper_result.get('segments', []) for word in seg.get('words', []) if i <= word['start'] < i + 1]
                            segment_text = " ".join(words_in_segment).strip()
                            if segment_text:
                                text_emotions = text_classifier(segment_text)[0]
                                ter_timeline[i] = {emo['label']: emo['score'] for emo in text_emotions}
                        # Clean up audio file
                        if os.path.exists(temp_audio_path):
                            os.unlink(temp_audio_path)
                video_clip.close()

                # Store results in session state
                st.session_state.full_transcription = full_transcription
                st.session_state.ser_timeline = ser_timeline
                st.session_state.ter_timeline = ter_timeline
                st.session_state.stage1_complete = True
                st.success("✅ Audio analysis complete! Speech and text emotions have been analyzed.")
                st.rerun()
            except Exception as e:
                st.error(f"Error during audio analysis: {str(e)}")
    else:
        st.success("✅ Stage 1 (Audio & Text Analysis) - Complete!")

    # Stage 2: Facial Analysis
    if st.session_state.get('stage1_complete', False) and not st.session_state.get('stage2_complete', False):
        if st.button("🎭 Step 2: Analyze Facial Expressions", type="primary"):
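            # Stage 2 samples roughly one frame per second of video and runs
            # DeepFace emotion analysis on each sampled frame.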
            try:
                fer_timeline = {}
                with st.spinner("Analyzing facial expressions..."):
                    cap = cv2.VideoCapture(st.session_state.temp_video_path)
                    fps = cap.get(cv2.CAP_PROP_FPS) or 30
                    frame_count = 0
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break
                        timestamp = frame_count / fps
                        if frame_count % int(fps) == 0:
                            analysis = DeepFace.analyze(frame, actions=['emotion'], enforce_detection=False, silent=True)
                            if isinstance(analysis, list) and len(analysis) > 0:
                                fer_timeline[timestamp] = {k: v / 100.0 for k, v in analysis[0]['emotion'].items()}
                        frame_count += 1
                    cap.release()

                # Store results in session state
                st.session_state.fer_timeline = fer_timeline
                st.session_state.stage2_complete = True
                st.success("✅ Facial analysis complete! All analyses are now finished.")
                st.rerun()
            except Exception as e:
                st.error(f"Error during facial analysis: {str(e)}")
    elif st.session_state.get('stage2_complete', False):
        st.success("✅ Stage 2 (Facial Expression Analysis) - Complete!")

    # Display results if both stages are complete
    if st.session_state.get('stage1_complete', False) and st.session_state.get('stage2_complete', False):
        display_results()

# Cleanup on app restart or when session ends
if st.session_state.temp_video_path and not uploaded_file:
    try:
        if os.path.exists(st.session_state.temp_video_path):
            os.unlink(st.session_state.temp_video_path)
        st.session_state.temp_video_path = None
    except Exception:
        pass