|
import os |
|
import torch |
|
import numpy as np |
|
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor |
|
|
|
|
|
# Hugging Face Hub identifier of the checkpoint that NationalityModel.load()
# passes to `from_pretrained`.
MODEL_ID = "facebook/mms-lid-256"

# Sample rate in Hz. Not referenced elsewhere in the visible code;
# NOTE(review): presumably the rate callers should resample audio to before
# calling NationalityModel.predict() -- confirm against the callers.
SAMPLING_RATE = 16000
|
|
|
class NationalityModel:
    """Wrapper around the audio classification checkpoint named by ``MODEL_ID``.

    The instance starts empty; call :meth:`load` to fetch the feature
    extractor and the model, then :meth:`predict` to classify a waveform.
    Predictions are reported using the label strings found in the model's
    ``config.id2label`` mapping (exposed as ``language_code`` entries).
    """

    def __init__(self, cache_dir=None):
        """Choose the download cache directory and create it if needed.

        Args:
            cache_dir: Directory handed to ``from_pretrained`` as its cache.
                When ``None``, ``/data/nationality`` is used if ``/data``
                exists, otherwise ``./cache/nationality``.

        Raises:
            OSError: If the cache directory cannot be created.
        """
        if cache_dir is None:
            # Prefer the /data location whenever that path exists on this
            # machine; fall back to a directory relative to the CWD.
            cache_dir = (
                "/data/nationality"
                if os.path.exists("/data")
                else "./cache/nationality"
            )
        self.cache_dir = cache_dir

        # Both stay None until load() succeeds; predict() checks for that.
        self.processor = None
        self.model = None

        os.makedirs(self.cache_dir, exist_ok=True)

    def load(self):
        """Load the feature extractor and the classification model.

        Both objects are fetched first and only then stored on the instance,
        so a failure part-way through never leaves a mismatched
        processor/model pair behind.

        Returns:
            bool: ``True`` when both objects were loaded, ``False`` when any
            error occurred (the error is printed, not raised).
        """
        try:
            print(f"Loading nationality prediction model from {MODEL_ID}...")
            print(f"Using cache directory: {self.cache_dir}")

            processor = AutoFeatureExtractor.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir,
            )
            model = Wav2Vec2ForSequenceClassification.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir,
            )
        except Exception as e:
            # Deliberate best-effort boundary: callers rely on the boolean
            # result instead of an exception.
            print(f"Error loading nationality prediction model: {e}")
            return False

        self.processor = processor
        self.model = model
        print("Nationality prediction model loaded successfully!")
        return True

    def predict(self, audio_data, sampling_rate):
        """Classify a waveform and return the most probable labels.

        Args:
            audio_data: Array-like audio samples. A 2-D input is mixed down
                to mono by averaging over its shorter axis (treated as the
                channel axis), so both ``(channels, frames)`` and
                ``(frames, channels)`` layouts are accepted.
            sampling_rate: Sample rate of ``audio_data`` in Hz; forwarded to
                the feature extractor unchanged.
                NOTE(review): the extractor is expected to reject rates that
                differ from the one the model was trained with -- confirm.

        Returns:
            dict: ``predicted_language`` (label of the highest-scoring
            class), ``confidence`` (its softmax probability) and
            ``top_languages`` (up to five ``{"language_code",
            "probability"}`` dicts, most probable first).

        Raises:
            ValueError: If :meth:`load` has not completed successfully.
            Exception: For any failure during preprocessing or inference;
                the underlying error is attached as ``__cause__``.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load() first.")

        try:
            waveform = np.asarray(audio_data)
            if waveform.size == 0:
                raise ValueError("audio_data is empty")

            if waveform.ndim == 2:
                # Channels are far fewer than frames in real audio, so the
                # shorter axis is the channel axis. Ties keep axis 0.
                channel_axis = 0 if waveform.shape[0] <= waveform.shape[1] else 1
                waveform = waveform.mean(axis=channel_axis)
            elif waveform.ndim > 2:
                waveform = waveform.mean(axis=0)

            waveform = waveform.astype(np.float32)

            inputs = self.processor(
                waveform,
                sampling_rate=sampling_rate,
                return_tensors="pt",
            )

            with torch.no_grad():
                logits = self.model(**inputs).logits

            # Index 0 selects the single item of the batch.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            id2label = self.model.config.id2label

            # topk raises when k exceeds the number of classes.
            k = min(5, probabilities.shape[-1])
            top_values, top_indices = torch.topk(probabilities, k=k)
            top_languages = [
                {"language_code": id2label[label_id], "probability": prob}
                for label_id, prob in zip(top_indices.tolist(), top_values.tolist())
            ]

            predicted_id = torch.argmax(logits, dim=-1)[0].item()

            return {
                "predicted_language": id2label[predicted_id],
                "confidence": probabilities[predicted_id].item(),
                "top_languages": top_languages,
            }

        except Exception as e:
            # Same exception type and message as before, now chained so the
            # original traceback is preserved.
            raise Exception(f"Nationality prediction error: {str(e)}") from e