|
import hashlib
import os
import shutil
import tempfile
from datetime import datetime

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import pipeline
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_whisper_model():
    """Build the speech-recognition pipeline for ``openai/whisper-tiny.en``.

    The tenacity decorator re-invokes this function for up to three attempts,
    waiting two seconds between them, whenever it raises.

    Returns:
        The pipeline object created by ``transformers.pipeline``.

    Raises:
        Whatever ``pipeline`` raised, after logging it; re-raising is what
        lets the retry decorator schedule another attempt.
    """
    pipeline_options = {
        "model": "openai/whisper-tiny.en",
        "device": -1,
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        asr_pipeline = pipeline("automatic-speech-recognition", **pipeline_options)
        print("Whisper model loaded successfully.")
        return asr_pipeline
    except Exception as load_error:
        print(f"Failed to load Whisper model: {load_error!s}")
        raise
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_symptom_model():
    """Load the symptom text-classification pipeline, with a generic fallback.

    First tries ``abhirajeshbhai/symptom-2-disease-net``. If that fails, the
    error is logged and ``distilbert-base-uncased`` is loaded instead. When
    the fallback is the model actually returned, the module-level
    ``is_fallback_model`` flag is set to True so that callers can label the
    predictions accordingly (previously the flag was never set on this path).

    The tenacity decorator re-invokes this function for up to three attempts,
    waiting two seconds between them, whenever it raises.

    Returns:
        The pipeline object created by ``transformers.pipeline``.

    Raises:
        The fallback model's exception when both models fail to load.
    """
    global is_fallback_model

    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,
            model_kwargs={"use_safetensors": True},
        )
        print("Symptom-2-Disease model loaded successfully.")
        return model
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")

        try:
            model = pipeline(
                "text-classification",
                model="distilbert-base-uncased",
                device=-1,
            )
            # Record that predictions now come from the generic fallback.
            is_fallback_model = True
            print("Fallback to distilbert-base-uncased model.")
            return model
        except Exception as fallback_e:
            print(f"Fallback model failed: {str(fallback_e)}")
            # Propagate so the retry decorator can try the whole sequence again.
            raise
|
|
|
# Module-level model handles. Both stay None when loading fails, and the
# functions that use them check for that before running inference.
whisper = symptom_classifier = None
is_fallback_model = False

# Load the speech-recognition model once at import time.
try:
    whisper = load_whisper_model()
except Exception as whisper_error:
    print(f"Whisper model initialization failed after retries: {whisper_error!s}")

# Load the symptom classifier once at import time.
try:
    symptom_classifier = load_symptom_model()
except Exception as symptom_error:
    print(f"Symptom model initialization failed after retries: {symptom_error!s}")
    symptom_classifier = None
    # NOTE(review): this branch runs only when no classifier could be loaded
    # at all, so the flag is raised here while symptom_classifier is None.
    is_fallback_model = True
|
|
|
def compute_file_hash(file_path, chunk_size=65536):
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is read in fixed-size chunks so that large files are never held
    in memory in full. The digest is used as a content fingerprint, not as a
    security measure.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes to read per iteration. The digest does not
            depend on this value; it only affects how the file is read.

    Returns:
        The digest as a 32-character lowercase hexadecimal string.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        # read(0) would return b"" immediately and silently hash nothing.
        raise ValueError("chunk_size must be a positive number of bytes")

    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
def transcribe_audio(audio_file):
    """Transcribe an audio file with the module-level Whisper pipeline.

    The audio is loaded at 16 kHz with librosa, rejected if it is too short
    or nearly silent, written to a private temporary WAV file and passed to
    the ``whisper`` pipeline. The temporary file is removed on every exit
    path, including when transcription raises.

    Args:
        audio_file: Path of the audio file to transcribe.

    Returns:
        The transcript on success. On failure a human-readable status string
        is returned instead of raising: it starts with ``"Error:"`` or
        ``"Error transcribing audio:"``, or is the "Transcription empty"
        message.
    """
    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."

    target_sr = 16000
    # 1600 samples at 16 kHz is 0.1 s; the message below states the same limit.
    min_samples = 1600

    temp_wav = None
    try:
        audio, sr = librosa.load(audio_file, sr=target_sr)
        if len(audio) < min_samples:
            return "Error: Audio too short. Please provide audio of at least 0.1 seconds."
        if np.max(np.abs(audio)) < 1e-4:
            return "Error: Audio too quiet. Please provide clear audio describing symptoms in English."

        # Use a uniquely named temp file so two requests whose inputs share a
        # basename cannot overwrite each other's audio.
        handle, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(handle)
        sf.write(temp_wav, audio, sr)

        with torch.no_grad():
            result = whisper(temp_wav, generate_kwargs={"num_beams": 5})
        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")

        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms in English."

        # Reject output where fewer than half of the words are distinct.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
    finally:
        # Best-effort cleanup; a failure here must not mask the real result.
        if temp_wav is not None:
            try:
                os.remove(temp_wav)
            except OSError as cleanup_error:
                print(f"Failed to delete temporary file {temp_wav}: {cleanup_error}")
|
|
|
def analyze_symptoms(text):
    """Classify a symptom description with the module-level classifier.

    Args:
        text: Transcribed symptom description.

    Returns:
        A ``(message, score)`` pair. On success ``message`` is the label of
        the first classifier result (suffixed with a note when
        ``is_fallback_model`` is set) and ``score`` is its score. Otherwise
        ``message`` is a status string and ``score`` is ``0.0``.
    """
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0

    try:
        unusable_text = (not text) or ("Error transcribing" in text)
        if unusable_text:
            return "No valid transcription for analysis.", 0.0

        with torch.no_grad():
            outputs = symptom_classifier(text)

        has_prediction = bool(outputs) and isinstance(outputs, list) and len(outputs) > 0
        if not has_prediction:
            return "No health condition predicted", 0.0

        # Only the first entry of the classifier output is reported.
        top_result = outputs[0]
        label = top_result["label"]
        confidence = top_result["score"]
        if is_fallback_model:
            print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
            label = f"{label} (using fallback model)"
        print(f"Health Prediction: {label}, Score: {confidence:.4f}")
        return label, confidence
    except Exception as analysis_error:
        return f"Error analyzing symptoms: {analysis_error!s}", 0.0
|
|
|
# Prefixes of the status strings transcribe_audio() returns instead of a
# transcript; such strings must never be fed to the symptom classifier.
_TRANSCRIPTION_FAILURE_PREFIXES = ("Error:", "Error transcribing", "Transcription empty")

# Prefixes of the status strings analyze_symptoms() returns instead of a label.
_PREDICTION_FAILURE_PREFIXES = ("Error:", "Error analyzing", "No valid transcription")


def analyze_voice(audio_file):
    """Turn a recorded symptom description into user-facing feedback.

    The input file is moved to a uniquely named path under the system temp
    directory, hashed, transcribed and classified. The moved file is deleted
    on every exit path, so the input is consumed: pass a disposable copy if
    the original must be kept.

    Args:
        audio_file: Path of the audio file to analyze, or a falsy value when
            no audio was supplied.

    Returns:
        The feedback text, or a status/error string when the audio could not
        be processed, transcribed or classified. This function does not raise.
    """
    if not audio_file:
        return "Error processing audio: no audio file provided."

    working_path = None
    try:
        # Stage the upload under a timestamped name so concurrent requests
        # with the same basename cannot clash. makedirs guards against a
        # missing directory; shutil.move also works across filesystems.
        staging_dir = os.path.join(tempfile.gettempdir(), "gradio")
        os.makedirs(staging_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        working_path = os.path.join(
            staging_dir, f"{timestamp}_{os.path.basename(audio_file)}"
        )
        shutil.move(audio_file, working_path)

        file_hash = compute_file_hash(working_path)
        print(f"Processing audio file: {working_path}, Hash: {file_hash}")

        # Loaded here only to log basic signal statistics.
        audio, sr = librosa.load(working_path, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")

        transcription = transcribe_audio(working_path)
        # Any status string (not just "Error transcribing") ends the request;
        # otherwise the error text itself would be classified as a symptom.
        if transcription.startswith(_TRANSCRIPTION_FAILURE_PREFIXES):
            return transcription

        lowered = transcription.lower()
        if "medicine" in lowered or "treatment" in lowered:
            feedback = "Error: This tool does not provide medication or treatment advice. Please describe symptoms only (e.g., 'I have a fever')."
            feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
            feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
            return feedback

        prediction, score = analyze_symptoms(transcription)
        # Do not present a classifier status string as a health condition.
        if prediction.startswith(_PREDICTION_FAILURE_PREFIXES):
            return prediction

        if prediction == "No health condition predicted":
            feedback = "No significant health indicators detected."
        else:
            feedback = f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor."

        feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', Prediction = {prediction}, Confidence = {score:.4f}, File Hash = {file_hash}"
        feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
        return feedback
    except Exception as e:
        return f"Error processing audio: {str(e)}"
    finally:
        # Remove the staged audio on every path, including early returns.
        if working_path is not None and os.path.exists(working_path):
            try:
                os.remove(working_path)
                print(f"Deleted temporary audio file: {working_path}")
            except OSError as e:
                print(f"Failed to delete audio file: {str(e)}")
|
|
|
def test_with_sample_audio():
    """Run the analyzer over the bundled sample recordings.

    ``analyze_voice`` moves and then deletes the file it is given, so each
    sample is first copied into a scratch directory and the copy is analyzed.
    The files under ``audio_samples/`` are left untouched and the check can
    be run repeatedly.

    Returns:
        One result per sample, joined with newlines. A missing sample yields
        a "Sample not found" line instead of an analysis.
    """
    samples = ["audio_samples/sample.wav", "audio_samples/common_voice_en.wav"]
    results = []
    for sample in samples:
        if not os.path.exists(sample):
            results.append(f"Sample not found: {sample}")
            continue
        # The scratch directory is removed afterwards whether or not the
        # copy inside it was moved away by analyze_voice.
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch_copy = shutil.copy2(sample, scratch_dir)
            results.append(analyze_voice(scratch_copy))
    return "\n".join(results)
|
|
|
|
|
# Gradio UI: one audio input (delivered to analyze_voice as a file path) and
# one text box that shows the returned feedback.
_voice_input = gr.Audio(type="filepath", label="Record or Upload Voice")
_feedback_output = gr.Textbox(label="Health Assessment Feedback")
_app_description = (
    "Record or upload a voice sample describing symptoms (e.g., 'I have a fever') "
    "for preliminary health assessment. Supports English only. "
    "Use clear audio (WAV, 16kHz). Do not ask for medication or treatment advice."
)

iface = gr.Interface(
    fn=analyze_voice,
    inputs=_voice_input,
    outputs=_feedback_output,
    title="Health Voice Analyzer",
    description=_app_description,
)
|
|
|
if __name__ == "__main__":
    # Print the sample-audio report first, then start the web UI.
    sample_report = test_with_sample_audio()
    print(sample_report)

    # 0.0.0.0 listens on all network interfaces, on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)