# Part 1: Essential Imports and Setup
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os
from scipy.stats import kurtosis, skew
from anthropic import Anthropic

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Initialize global model variables
processor = None
whisper_model = None
emotion_tokenizer = None
emotion_model = None
clinical_analyzer = None
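# Note: the clinical analysis step (Part 4) reads the ANTHROPIC_API_KEY
# environment variable; if it is missing or the API call fails, the app falls
# back to the rule-based summary in _generate_backup_analysis().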
# Part 2: Model Loading and Initialization
def load_models():
    """Load and initialize speech and emotion analysis models."""
    global processor, whisper_model, emotion_tokenizer, emotion_model
    try:
        # Initialize speech recognition (Whisper) model
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        # Initialize emotion detection model
        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")

        # Set models to CPU for consistent performance
        device = "cpu"
        whisper_model.to(device)
        emotion_model.to(device)

        print("Models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False
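# Note: from_pretrained() fetches both checkpoints from the Hugging Face Hub on
# first run and reuses the local cache afterwards, so the first startup is
# noticeably slower than subsequent ones.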
# Part 3: Voice Feature Extraction
def extract_prosodic_features(waveform, sr):
    """Extract voice features including pitch, energy, and rhythm patterns."""
    try:
        # Input validation
        if waveform is None or len(waveform) == 0:
            return None

        features = {}

        # Pitch analysis
        try:
            pitches, magnitudes = librosa.piptrack(
                y=waveform,
                sr=sr,
                fmin=50,        # Minimum human voice frequency
                fmax=2000,      # Maximum human voice frequency
                hop_length=512,
                win_length=2048
            )
            # Extract valid pitch contour
            f0_contour = [
                pitches[magnitudes[:, t].argmax(), t]
                for t in range(pitches.shape[1])
                if 50 <= pitches[magnitudes[:, t].argmax(), t] <= 2000
            ]
            # Calculate pitch statistics
            if f0_contour:
                features['pitch_mean'] = float(np.mean(f0_contour))
                features['pitch_std'] = float(np.std(f0_contour))
                features['pitch_range'] = float(np.ptp(f0_contour))
            else:
                features['pitch_mean'] = 160.0  # Default adult pitch
                features['pitch_std'] = 0.0
                features['pitch_range'] = 0.0
        except Exception as e:
            print(f"Pitch extraction error: {e}")
            features.update({'pitch_mean': 160.0, 'pitch_std': 0.0, 'pitch_range': 0.0})

        # Energy analysis
        try:
            rms = librosa.feature.rms(
                y=waveform,
                frame_length=2048,
                hop_length=512,
                center=True
            )[0]
            features.update({
                'energy_mean': float(np.mean(rms)),
                'energy_std': float(np.std(rms)),
                'energy_range': float(np.ptp(rms))
            })
        except Exception as e:
            print(f"Energy extraction error: {e}")
            features.update({'energy_mean': 0.02, 'energy_std': 0.0, 'energy_range': 0.0})

        # Rhythm analysis
        try:
            onset_env = librosa.onset.onset_strength(
                y=waveform,
                sr=sr,
                hop_length=512,
                aggregate=np.median
            )
            tempo = librosa.beat.tempo(
                onset_envelope=onset_env,
                sr=sr,
                hop_length=512,
                aggregate=None
            )[0]
            # Keep the estimate only if it falls in a plausible range
            features['tempo'] = float(tempo) if 40 <= tempo <= 240 else 120.0
        except Exception as e:
            print(f"Rhythm extraction error: {e}")
            features['tempo'] = 120.0

        return features
    except Exception as e:
        print(f"Feature extraction failed: {e}")
        return None
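# Illustrative usage (assumes a local file; "sample.wav" is a placeholder, not
# an asset shipped with this Space):
#   y, sr = librosa.load("sample.wav", sr=16000)
#   print(extract_prosodic_features(y, sr))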
# Part 4: Clinical Analysis Integration
class ClinicalVoiceAnalyzer:
    """Analyze voice characteristics for psychological indicators."""

    def __init__(self):
        """Initialize the clinical analyzer with API client and reference ranges."""
        self.anthropic = Anthropic(api_key=os.getenv('ANTHROPIC_API_KEY'))
        self.model = "claude-3-opus-20240229"
        self.reference_ranges = {
            'pitch': {'min': 150, 'max': 400},
            'tempo': {'min': 90, 'max': 130},
            'energy': {'min': 0.01, 'max': 0.05}
        }
        print("Clinical analyzer ready")

    def analyze_voice_metrics(self, features, emotions, transcription):
        """Generate clinical insights from voice and emotion data."""
        try:
            prompt = self._create_clinical_prompt(features, emotions, transcription)
            response = self.anthropic.messages.create(
                model=self.model,
                max_tokens=1000,
                messages=[{"role": "user", "content": prompt}]
            )
            # response.content is a list of content blocks; use the text of the first
            return self._format_analysis(response.content[0].text)
        except Exception as e:
            print(f"Clinical analysis error: {e}")
            return self._generate_backup_analysis(features, emotions)

    def _create_clinical_prompt(self, features, emotions, transcription):
        """Create a detailed prompt for clinical analysis."""
        return f"""As a clinical voice analysis expert, provide a psychological assessment of:

Voice Metrics:
- Pitch: {features['pitch_mean']:.2f} Hz (Normal: {self.reference_ranges['pitch']['min']}-{self.reference_ranges['pitch']['max']} Hz)
- Pitch Variation: {features['pitch_std']:.2f} Hz
- Speech Rate: {features['tempo']:.2f} BPM (Normal: {self.reference_ranges['tempo']['min']}-{self.reference_ranges['tempo']['max']} BPM)
- Voice Energy: {features['energy_mean']:.4f}

Emotions Detected:
{', '.join(f'{emotion}: {score:.1%}' for emotion, score in emotions.items())}

Speech Content:
"{transcription}"

Provide:
1. Voice characteristic analysis
2. Emotional state assessment
3. Anxiety/depression indicators
4. Stress level evaluation
5. Clinical recommendations"""

    def _format_analysis(self, analysis):
        """Format the clinical analysis output."""
        return f"\nClinical Assessment:\n{analysis}"

    def _generate_backup_analysis(self, features, emotions):
        """Generate a basic analysis when the API is unavailable."""
        dominant_emotion = max(emotions.items(), key=lambda x: x[1])
        pitch_status = (
            "elevated" if features['pitch_mean'] > self.reference_ranges['pitch']['max']
            else "reduced" if features['pitch_mean'] < self.reference_ranges['pitch']['min']
            else "normal"
        )
        return f"""
Basic Voice Analysis (API Unavailable):
- Pitch Status: {pitch_status} ({features['pitch_mean']:.2f} Hz)
- Speech Rate: {features['tempo']:.2f} BPM
- Voice Energy Level: {features['energy_mean']:.4f}
- Primary Emotion: {dominant_emotion[0]} ({dominant_emotion[1]:.1%} confidence)"""
# Part 5: Visualization Functions
def create_feature_plots(features):
    """Create interactive visualizations of voice features."""
    try:
        fig = go.Figure()

        # Pitch visualization
        pitch_data = {
            'Mean': features['pitch_mean'],
            'Std Dev': features['pitch_std'],
            'Range': features['pitch_range']
        }
        fig.add_trace(go.Bar(
            name='Pitch Features (Hz)',
            x=list(pitch_data.keys()),
            y=list(pitch_data.values()),
            marker_color='blue'
        ))

        # Energy visualization
        energy_data = {
            'Mean': features['energy_mean'],
            'Std Dev': features['energy_std'],
            'Range': features['energy_range']
        }
        fig.add_trace(go.Bar(
            name='Energy Features',
            x=[f"Energy {k}" for k in energy_data.keys()],
            y=list(energy_data.values()),
            marker_color='red'
        ))

        # Tempo visualization
        fig.add_trace(go.Scatter(
            name='Speech Rate (BPM)',
            x=['Tempo'],
            y=[features['tempo']],
            mode='markers',
            marker=dict(size=15, color='green')
        ))

        # Layout configuration
        fig.update_layout(
            title='Voice Feature Analysis',
            showlegend=True,
            height=600,
            barmode='group',
            xaxis_title='Feature Type',
            yaxis_title='Value',
            template='plotly_white'
        )
        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Plot creation error: {e}")
        return None
def create_emotion_plot(emotions):
    """Create visualization of emotional analysis."""
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()),
                y=list(emotions.values()),
                # One colour per emotion class reported by the model
                marker_color=['#FF9999', '#66B2FF', '#99FF99', '#FFCC99',
                              '#FF99CC', '#99FFFF', '#CCCCFF'][:len(emotions)]
            )
        ])
        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Confidence Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )
        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Emotion plot error: {e}")
        return None
# Part 6: Main Analysis Function
def analyze_audio(audio_input):
    """Process audio input and generate a comprehensive analysis."""
    try:
        # Validate input
        if audio_input is None:
            return "Please provide an audio input", None, None

        # Load audio (Gradio passes a file path when type="filepath")
        audio_path = audio_input[0] if isinstance(audio_input, tuple) else audio_input
        waveform, sr = librosa.load(audio_path, sr=16000, duration=30)

        # Validate duration
        duration = len(waveform) / sr
        if duration < 0.5:
            return "Audio too short (minimum 0.5 seconds needed)", None, None

        # Extract features
        features = extract_prosodic_features(waveform, sr)
        if features is None:
            return "Feature extraction failed", None, None

        # Generate feature visualization
        feature_viz = create_feature_plots(features)

        # Perform speech recognition
        inputs = processor(waveform, sampling_rate=sr, return_tensors="pt").input_features
        with torch.no_grad():
            predicted_ids = whisper_model.generate(inputs)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        # Analyze emotions in the transcribed text
        emotion_inputs = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            emotion_outputs = emotion_model(**emotion_inputs)
            emotions = torch.nn.functional.softmax(emotion_outputs.logits, dim=-1)

        # Map scores to labels using the model's own config so the order matches
        # the classifier head (this checkpoint outputs seven classes, including 'disgust')
        emotion_scores = {
            emotion_model.config.id2label[i]: float(score)
            for i, score in enumerate(emotions[0].cpu().numpy())
        }
        emotion_viz = create_emotion_plot(emotion_scores)

        # Generate clinical analysis
        global clinical_analyzer
        if clinical_analyzer is None:
            clinical_analyzer = ClinicalVoiceAnalyzer()
        clinical_analysis = clinical_analyzer.analyze_voice_metrics(
            features, emotion_scores, transcription
        )

        # Create comprehensive summary
        summary = f"""Voice Analysis Summary:

Speech Content:
{transcription}

Voice Characteristics:
- Average Pitch: {features['pitch_mean']:.2f} Hz
- Pitch Variation: {features['pitch_std']:.2f} Hz
- Speech Rate (Tempo): {features['tempo']:.2f} BPM
- Voice Energy: {features['energy_mean']:.4f}

Dominant Emotion: {max(emotion_scores.items(), key=lambda x: x[1])[0]}
Emotion Confidence: {max(emotion_scores.values()):.2%}
Recording Duration: {duration:.2f} seconds
{clinical_analysis}"""

        return summary, emotion_viz, feature_viz
    except Exception as e:
        error_msg = f"Analysis failed: {str(e)}"
        print(error_msg)
        return error_msg, None, None
# Part 7: Application Initialization
try:
    print("===== Application Startup =====")

    # Load required models
    if not load_models():
        raise RuntimeError("Model loading failed")

    # Initialize clinical analyzer
    clinical_analyzer = ClinicalVoiceAnalyzer()
    print("Clinical analyzer initialized")

    # Create Gradio interface
    demo = gr.Interface(
        fn=analyze_audio,
        inputs=gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Audio Input (Recommended: 1-5 seconds of clear speech)"
        ),
        outputs=[
            gr.Textbox(label="Analysis Summary", lines=15),
            gr.HTML(label="Emotion Analysis"),
            gr.HTML(label="Voice Feature Analysis")
        ],
        title="Voice Analysis System with Clinical Interpretation",
description=""" | |
This application provides comprehensive voice analysis with clinical insights: | |
1. Voice Features: | |
- Pitch analysis (fundamental frequency and variation) | |
- Energy patterns (volume and intensity) | |
- Speech rate (words per minute) | |
- Voice quality metrics | |
2. Clinical Analysis: | |
- Mental health indicators | |
- Emotional state evaluation | |
- Risk assessment | |
- Clinical recommendations | |
3. Emotional Content: | |
- Emotion detection (6 basic emotions) | |
- Emotional intensity analysis | |
For optimal description=""" | |
This application provides comprehensive voice analysis with clinical insights:

1. Voice Features:
- Pitch analysis (fundamental frequency and variation)
- Energy patterns (volume and intensity)
- Speech rate (tempo)
- Voice quality metrics

2. Clinical Analysis:
- Mental health indicators
- Emotional state evaluation
- Risk assessment
- Clinical recommendations

3. Emotional Content:
- Emotion detection (anger, disgust, fear, joy, neutral, sadness, surprise)
- Emotional intensity analysis

For optimal results:
- Record in a quiet environment
- Speak clearly and naturally
- Keep recordings between 1-5 seconds
- Maintain consistent volume

Upload an audio file or record directly through your microphone.
""",
        examples=None,
        cache_examples=False
    )
    # Launch the interface
    if __name__ == "__main__":
        demo.launch()

except Exception as e:
    print(f"Error during application startup: {str(e)}")
    raise
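# To run outside Spaces: `python app.py`; Gradio serves the interface on
# http://127.0.0.1:7860 by default.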