# SpeechAnalysisDemo / models / nationality_model.py
# Committed by: dtrovato997
# Commit message: "fix: clip audio to max 2 mins"
# Commit: 5277669
import os
import torch
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
# Module-level configuration.
# Hugging Face Hub checkpoint used for spoken-language identification.
MODEL_ID: str = "facebook/mms-lid-256"
# Audio sample rate in Hz (16 kHz) that input waveforms are expected to use.
SAMPLING_RATE: int = 16_000
class NationalityModel:
    """Spoken-language identifier backed by the ``MODEL_ID`` checkpoint.

    Despite the class name, :meth:`predict` returns language codes taken from
    the model's ``id2label`` mapping, not nationalities.

    Typical use::

        model = NationalityModel()
        if model.load():
            result = model.predict(waveform, 16000)
    """

    def __init__(self, cache_dir=None):
        """Choose (and create) the directory used to cache downloaded weights.

        Args:
            cache_dir: Explicit cache directory. When ``None``,
                ``/data/nationality`` is used if ``/data`` exists (persistent
                storage on Hugging Face Spaces), otherwise
                ``./cache/nationality``.
        """
        if cache_dir is None:
            if os.path.exists("/data"):
                # HF Spaces persistent storage
                self.cache_dir = "/data/nationality"
            else:
                # Local development or other platforms
                self.cache_dir = "./cache/nationality"
        else:
            self.cache_dir = cache_dir
        # Populated by load(); predict() refuses to run while these are None.
        self.processor = None
        self.model = None
        os.makedirs(self.cache_dir, exist_ok=True)

    def load(self):
        """Load the feature extractor and classification model for ``MODEL_ID``.

        Weights are fetched into / read from ``self.cache_dir``.

        Returns:
            ``True`` on success, ``False`` on any failure. Errors are printed
            rather than raised so callers can degrade gracefully.
        """
        try:
            print(f"Loading nationality prediction model from {MODEL_ID}...")
            print(f"Using cache directory: {self.cache_dir}")
            self.processor = AutoFeatureExtractor.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir,
            )
            self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir,
            )
            print("Nationality prediction model loaded successfully!")
            return True
        except Exception as e:
            # Deliberate best-effort: report and signal failure via the return value.
            print(f"Error loading nationality prediction model: {e}")
            return False

    def predict(self, audio_data, sampling_rate, top_k=5):
        """Identify the spoken language in an audio clip.

        NOTE(review): the whole clip is passed to the model; no duration limit
        is applied here. The checkpoint presumably expects 16 kHz audio (see
        ``SAMPLING_RATE``), so resample before calling if needed.

        Args:
            audio_data: Waveform as a NumPy array or array-like. It may be 1-D
                mono or multi-channel in either (channels, samples) or
                (samples, channels) layout. Multi-channel input is averaged
                down to mono.
            sampling_rate: Sample rate of ``audio_data`` in Hz, forwarded to
                the feature extractor.
            top_k: Number of ranked candidates to return (default 5). It is
                clamped to the range [1, number of model labels].

        Returns:
            A dict with three keys:

            * ``predicted_language``: label of the most probable class.
            * ``confidence``: softmax probability of that class.
            * ``top_languages``: list of ``{"language_code", "probability"}``
              dicts, highest probability first.

        Raises:
            ValueError: if the model/processor have not been loaded.
            Exception: wrapping any error raised during preprocessing or
                inference. The underlying error is chained as ``__cause__``.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load() first.")
        try:
            audio = np.asarray(audio_data)
            if audio.ndim > 1:
                # Down-mix to mono. The longest axis is treated as time so both
                # (channels, samples) and (samples, channels) layouts work.
                time_axis = int(np.argmax(audio.shape))
                channel_axes = tuple(
                    axis for axis in range(audio.ndim) if axis != time_axis
                )
                audio = audio.mean(axis=channel_axes)
            audio = audio.astype(np.float32)
            if audio.size == 0:
                raise ValueError("audio_data is empty")

            inputs = self.processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
            with torch.no_grad():
                logits = self.model(**inputs).logits

            # Softmax over the label dimension; [0] selects the single batch item.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            id2label = self.model.config.id2label

            # Never request more candidates than the model has labels.
            k = max(1, min(int(top_k), probabilities.numel()))
            top_values, top_indices = torch.topk(probabilities, k=k)
            top_languages = [
                {
                    "language_code": id2label[index.item()],
                    "probability": value.item(),
                }
                for value, index in zip(top_values, top_indices)
            ]

            # Softmax is monotonic, so the argmax over probabilities equals the
            # argmax over the raw logits.
            predicted_lang_id = torch.argmax(probabilities).item()
            return {
                "predicted_language": id2label[predicted_lang_id],
                "confidence": probabilities[predicted_lang_id].item(),
                "top_languages": top_languages,
            }
        except Exception as e:
            raise Exception(f"Nationality prediction error: {str(e)}") from e