# finalxls-r-mms / app.py
# kasimali's picture
# Upload folder using huggingface_hub
# 50ce094 verified
# FINALXLS-R-MMS
# ============================================================================
# CELL 1: SETUP AND INSTALLATION
# ============================================================================
import os
import warnings
# Silence every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')
# Banner for the run, followed by a 60-character separator line.
print("πŸš€ MMS Language Identification Test (Final Corrected Version)")
print("=" * 60)
# Import the Colab Drive helper.
# NOTE(review): no drive.mount(...) call is visible in this file, although the
# data paths defined later live under /content/drive - confirm the mount is
# performed elsewhere before this script runs.
from google.colab import drive
# NOTE(review): only status messages are printed here; no package install
# command is visible in this file (presumably notebook shell commands were
# dropped when the notebook was exported - confirm).
print("πŸ“¦ Installing and updating packages...")
print("βœ… Setup complete! Please restart the runtime now to apply updates.")
# ============================================================================
# CELL 2: MODEL LOADING AND MAPPINGS (CORRECTED)
# ============================================================================
import torch
import librosa
import pandas as pd
import numpy as np
from datetime import datetime
from transformers import Wav2Vec2FeatureExtractor, AutoModelForAudioClassification
from sklearn.metrics import accuracy_score, classification_report
# Ground truth: dataset folder name -> reference language code used when
# scoring predictions. Folder names are a mix of 2- and 3-letter codes.
# Note that Punjabi keeps the 2-letter code 'pa' as its reference label.
CUSTOM_FOLDER_MAPPING = {
    "as": "asm",
    "bn": "ben",
    "br": "brx",
    "doi": "dgo",
    "en": "eng",
    "gu": "guj",
    "hi": "hin",
    "kn": "kan",
    "kok": "kok",
    "ks": "kas",
    "mai": "mai",
    "ml": "mal",
    "mni": "mni",
    "mr": "mar",
    "ne": "nep",
    "or": "ory",
    "pa": "pa",
    "sa": "san",
    "sat": "sat",
    "sd": "snd",
    "ta": "tam",
    "te": "tel",
    "ur": "urd",
}
# Model label -> reference label. Keys are the codes the model emits; values
# are the codes used as ground truth. Most entries are identity mappings; the
# two exceptions translate 'pan' and 'npi' to the reference codes 'pa' and
# 'nep'. Labels missing from this table are reported as 'unknown' by callers
# that look them up with a default.
NORMALIZATION_MAP = {
    "asm": "asm",
    "ben": "ben",
    "brx": "brx",
    "dgo": "dgo",
    "eng": "eng",
    "guj": "guj",
    "hin": "hin",
    "kan": "kan",
    "kok": "kok",
    "kas": "kas",
    "mai": "mai",
    "mal": "mal",
    "mni": "mni",
    "mar": "mar",
    "ory": "ory",
    "pan": "pa",  # Punjabi: model code -> reference code
    "san": "san",
    "sat": "sat",
    "snd": "snd",
    "tam": "tam",
    "tel": "tel",
    "urd": "urd",
    "npi": "nep",  # Nepali: model code -> reference code
}
# Reference language code -> human-readable language name, used to label
# rows in the printed report.
ISO_TO_FULL_NAME = {
    "asm": "Assamese",
    "ben": "Bengali",
    "brx": "Bodo",
    "dgo": "Dogri",
    "eng": "English",
    "guj": "Gujarati",
    "hin": "Hindi",
    "kan": "Kannada",
    "kok": "Konkani",
    "kas": "Kashmiri",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mni": "Manipuri",
    "mar": "Marathi",
    "nep": "Nepali",
    "ory": "Odia",
    "pa": "Punjabi",
    "san": "Sanskrit",
    "sat": "Santali",
    "snd": "Sindhi",
    "tam": "Tamil",
    "tel": "Telugu",
    "urd": "Urdu",
}
# --- Paths and Model Loading ---
# Root folder that is scanned for audio, and the folder that receives result
# CSVs. Both sit under /content/drive/MyDrive (presumably the Colab Google
# Drive mount point - confirm).
AUDIO_FOLDER = "/content/drive/MyDrive/Audio_files"
RESULTS_FOLDER = "/content/drive/MyDrive/mms_lid_results"
# Create the results folder up front; an already existing folder is fine.
os.makedirs(RESULTS_FOLDER, exist_ok=True)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"πŸ”§ Device: {device}")
# Checkpoint identifier used for both the feature extractor and the classifier.
MODEL_NAME = "facebook/mms-lid-256"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
# Load the classifier and move its weights to the selected device.
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(device)
# Put the model in evaluation (inference) mode.
model.eval()
print(f"βœ… MMS LID model and feature extractor loaded successfully: {MODEL_NAME}")
# ============================================================================
# CELL 3: AUDIO PROCESSING AND PREDICTION (CORRECTED)
# ============================================================================
def load_audio_raw(file_path):
    """Load an audio file as a mono waveform resampled to 16 kHz.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        ``(samples, duration_seconds)`` on success. If loading raises any
        exception, the error is printed and ``(None, 0)`` is returned.
    """
    try:
        samples, _ = librosa.load(file_path, sr=16000, mono=True)
        seconds = len(samples) / 16000
    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"Error loading {file_path}: {e}")
        return None, 0
    return samples, seconds
def predict_language_mms_top5(audio_array):
    """Predict the five most likely target languages for one waveform.

    The model's softmax output is restricted to the target languages,
    re-normalized over that subset, and the highest-scoring labels returned.

    Args:
        audio_array: Waveform samples; handed to the feature extractor with
            ``sampling_rate=16000``.

    Returns:
        ``(codes, probs)`` where ``codes`` is a list of model label codes
        ordered from most to least likely and ``probs`` is a NumPy array of
        the matching re-normalized probabilities. If anything fails, the
        error is printed and ``(["error"], [0.0])`` is returned.
    """
    try:
        inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1)[0]

        # --- Whitelist Logic ---
        # The whitelist must be expressed in the MODEL's label space, i.e. the
        # keys of NORMALIZATION_MAP. The ground-truth codes in
        # CUSTOM_FOLDER_MAPPING contain 'pa' and 'nep', while NORMALIZATION_MAP
        # records the model emitting 'pan' and 'npi'; building the whitelist
        # from the ground-truth codes silently masked out Punjabi and Nepali.
        label2id = model.config.label2id
        target_indices = [label2id[code] for code in NORMALIZATION_MAP if code in label2id]

        # Zero out every non-target language.
        mask = torch.zeros_like(probabilities)
        if target_indices:
            mask[target_indices] = 1
        masked_probs = probabilities * mask

        # Re-normalize over the whitelist; skip when nothing survived the mask
        # to avoid dividing by zero.
        total = masked_probs.sum()
        renormalized_probs = masked_probs / total if total > 0 else masked_probs

        # Top-5 among the whitelisted languages (capped by the label count).
        k = min(5, renormalized_probs.numel())
        top5_probs, top5_indices = torch.topk(renormalized_probs, k)
        top5_lang_codes = [model.config.id2label[idx.item()] for idx in top5_indices]
        return top5_lang_codes, top5_probs.cpu().numpy()
    except Exception as e:
        # Best-effort: report the failure instead of hiding it, then return
        # the sentinel the caller already understands.
        print(f"Error during prediction: {e}")
        return ["error"], [0.0]
def find_audio_files(base_path):
    """Collect audio files stored in recognised language folders.

    Walks ``base_path`` recursively. A directory counts as a language folder
    when its lower-cased name is a key of ``CUSTOM_FOLDER_MAPPING``; files
    directly inside it with a known audio extension are collected.

    Args:
        base_path: Root directory to scan.

    Returns:
        A list of dicts with keys ``file_path``, ``filename`` and
        ``ground_truth`` (the mapped language code of the containing folder).
    """
    audio_suffixes = ('.wav', '.mp3', '.m4a', '.flac', '.ogg')
    found = []
    for current_dir, _, names in os.walk(base_path):
        language_dir = os.path.basename(current_dir).lower()
        if language_dir not in CUSTOM_FOLDER_MAPPING:
            continue  # not a language folder
        label = CUSTOM_FOLDER_MAPPING[language_dir]
        found.extend(
            {
                "file_path": os.path.join(current_dir, name),
                "filename": name,
                "ground_truth": label,
            }
            for name in names
            if name.lower().endswith(audio_suffixes)
        )
    return found
# Status message printed once the helper definitions above have been executed.
print("βœ… Corrected prediction functions are ready!")
# ============================================================================
# CELL 4: PROCESS ALL FILES AND GENERATE REPORT (CORRECTED)
# ============================================================================
def run_full_analysis_corrected():
    """Classify every discovered audio file and report accuracy.

    Steps:
      1. Discover audio files under ``AUDIO_FOLDER``.
      2. Load each file and predict its top-5 languages, normalizing the
         model labels through ``NORMALIZATION_MAP`` (unmapped -> 'unknown').
      3. Save one row per file to a timestamped CSV in ``RESULTS_FOLDER``.
      4. Print overall top-1 / top-5 accuracy and a per-language report,
         computed over the files that could be loaded.

    Returns:
        None. Output goes to the CSV file and to stdout. Returns early when
        no audio files are found or when none of them could be loaded.
    """
    print("πŸš€ Processing FULL dataset with Corrected Top-5 Logic...")
    audio_files = find_audio_files(AUDIO_FOLDER)
    if not audio_files:
        print("❌ No audio files found.")
        return

    results = []
    total = len(audio_files)
    print(f"πŸ”„ Processing {total} files...")
    for position, file_info in enumerate(audio_files, start=1):
        if position % 100 == 0:
            print(f"Progress: {position}/{total}")
        audio, duration = load_audio_raw(str(file_info['file_path']))
        if audio is None:
            # Keep a row for unreadable files so they show up in the CSV.
            results.append({**file_info, 'predicted_language': 'load_error', 'top5_predictions': [], 'confidence': 0.0, 'duration': 0.0})
            continue
        top5_langs, top5_probs = predict_language_mms_top5(audio)
        # Translate model labels into the ground-truth code space.
        normalized_top5 = [NORMALIZATION_MAP.get(lang, 'unknown') for lang in top5_langs]
        results.append({
            **file_info,
            'predicted_language': normalized_top5[0],  # Top-1 prediction
            'confidence': top5_probs[0],
            'duration': duration,
            'is_short_file': duration < 3.0,
            'top5_predictions': normalized_top5
        })

    results_df = pd.DataFrame(results)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = f"{RESULTS_FOLDER}/mms_corrected_top5_results_{timestamp}.csv"
    results_df.to_csv(csv_path, index=False)
    print(f"\nβœ… Processing complete! Results saved to: {csv_path}")

    # --- Final Detailed Analysis ---
    print("\n" + "=" * 60)
    print("πŸ“Š MMS LID MODEL - FINAL CORRECTED ANALYSIS")
    print("=" * 60)
    valid_df = results_df[results_df['predicted_language'] != 'load_error'].copy()
    if valid_df.empty:
        # Every file failed to load: the metrics below are undefined on an
        # empty frame, so stop here (the CSV has already been written).
        print("❌ No files could be loaded; skipping accuracy analysis.")
        return

    # Top-1 accuracy: exact match between reference and first prediction.
    top1_accuracy = accuracy_score(valid_df['ground_truth'], valid_df['predicted_language'])
    # Top-5 accuracy: reference label appears anywhere in the top-5 list.
    valid_df['is_top5_correct'] = [
        truth in top5
        for truth, top5 in zip(valid_df['ground_truth'], valid_df['top5_predictions'])
    ]
    top5_accuracy = valid_df['is_top5_correct'].mean()
    print(f"\n🎯 OVERALL TOP-1 ACCURACY: {top1_accuracy:.2%}")
    print(f"🎯 OVERALL TOP-5 ACCURACY: {top5_accuracy:.2%}")
    print(f"\nπŸ“‹ LANGUAGE-WISE ACCURACY:")
    report = classification_report(
        valid_df['ground_truth'],
        valid_df['predicted_language'],
        output_dict=True,
        zero_division=0,
    )
    report_df = pd.DataFrame(report).transpose()
    # Rows without an entry in ISO_TO_FULL_NAME get a missing Language value.
    report_df['Language'] = report_df.index.map(ISO_TO_FULL_NAME)
    print(report_df[['Language', 'precision', 'recall', 'f1-score', 'support']])
# Run the full evaluation immediately at module level; note that this also
# fires whenever the file is imported, not only when it is executed directly.
run_full_analysis_corrected()