from typing import List, Dict
from collections import defaultdict
import os
import re
import time

import numpy as np
import librosa
import nltk
import eng_to_ipa as ipa
from loguru import logger

from src.AI_Models.wave2vec_inference import (
    Wave2Vec2Inference,
    Wave2Vec2ONNXInference,
    export_to_onnx,
)

# Download required NLTK data
try:
    nltk.download("cmudict", quiet=True)
    from nltk.corpus import cmudict
except Exception:
    print("Warning: NLTK data not available")


class Wav2Vec2CharacterASR:
    """Wav2Vec2 character-level ASR with support for both ONNX and Transformers inference"""

    def __init__(
        self,
        model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
        onnx: bool = False,
        quantized: bool = False,
    ):
        """
        Initialize Wav2Vec2 character-level model

        Args:
            model_name: HuggingFace model name
            onnx: If True, use ONNX Runtime for inference. If False, use Transformers
            quantized: If True, export/use the quantized ONNX model (only used if onnx=True)
        """
        self.use_onnx = onnx
        self.sample_rate = 16000
        self.model_name = model_name
        self.onnx_model_path = (
            "wav2vec2-large-960h-lv60-self"
            + (".quant" if quantized else "")
            + ".onnx"
        )

        # Check whether the ONNX model file already exists; export it if not
        if onnx:
            if not os.path.exists(self.onnx_model_path):
                export_to_onnx(model_name, quantize=quantized)

        self.model = (
            Wave2Vec2Inference(model_name)
            if not onnx
            else Wave2Vec2ONNXInference(model_name, self.onnx_model_path)
        )

    def transcribe_to_characters(self, audio_path: str) -> Dict:
        """Transcribe audio to a character transcript plus a phoneme-like representation"""
        try:
            start_time = time.time()
            character_transcript = self.model.file_to_text(audio_path)
            character_transcript = self._clean_character_transcript(
                character_transcript
            )
            phoneme_like_transcript = self._characters_to_phoneme_representation(
                character_transcript
            )
            logger.info(f"Transcription time: {time.time() - start_time:.2f}s")

            return {
                "character_transcript": character_transcript,
                "phoneme_representation": phoneme_like_transcript,
            }
        except Exception as e:
            print(f"Transcription error: {e}")
            return self._empty_result()

    def _calculate_confidence_scores(self, logits: np.ndarray) -> List[float]:
        """Calculate confidence scores from logits using numpy"""
        # Apply softmax over the vocabulary dimension
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        softmax_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

        # Keep the highest probability per frame (first item in the batch)
        max_probs = np.max(softmax_probs, axis=-1)[0]
        return max_probs.tolist()

    def _clean_character_transcript(self, transcript: str) -> str:
        """Clean and standardize character transcript"""
        # Remove extra spaces and special tokens
        logger.info(f"Raw transcript before cleaning: {transcript}")
        cleaned = re.sub(r"\s+", " ", transcript)
        cleaned = cleaned.strip().lower()
        return cleaned

    def _characters_to_phoneme_representation(self, text: str) -> str:
        """Convert character-based transcript to a phoneme-like representation for comparison"""
        if not text:
            return ""

        words = text.split()
        phoneme_words = []
        g2p = SimpleG2P()

        for word in words:
            try:
                word_data = g2p.text_to_phonemes(word)[0]
                phoneme_words.extend(word_data["phonemes"])
            except Exception:
                # Fallback: simple letter-to-sound mapping
                phoneme_words.extend(self._simple_letter_to_phoneme(word))

        return " ".join(phoneme_words)

    def _simple_letter_to_phoneme(self, word: str) -> List[str]:
        """Simple fallback letter-to-phoneme conversion"""
        letter_to_phoneme = {
            "a": "æ", "b": "b", "c": "k", "d": "d", "e": "ɛ", "f": "f",
            "g": "ɡ", "h": "h", "i": "ɪ", "j": "dʒ", "k": "k", "l": "l",
            "m": "m", "n": "n", "o": "ʌ", "p": "p", "q": "k", "r": "r",
            "s": "s", "t": "t", "u": "ʌ", "v": "v", "w": "w", "x": "ks",
            "y": "j", "z": "z",
        }
"k", "r": "r", "s": "s", "t": "t", "u": "ʌ", "v": "v", "w": "w", "x": "ks", "y": "j", "z": "z", } phonemes = [] for letter in word.lower(): if letter in letter_to_phoneme: phonemes.append(letter_to_phoneme[letter]) return phonemes def _empty_result(self) -> Dict: """Return empty result structure""" return { "character_transcript": "", "phoneme_representation": "", "raw_predicted_ids": [], "confidence_scores": [], } def get_model_info(self) -> Dict: """Get information about the loaded model""" info = { "model_name": self.model_name, "sample_rate": self.sample_rate, "inference_method": "ONNX" if self.use_onnx else "Transformers", } if self.use_onnx: info.update( { "onnx_model_path": self.onnx_model_path, "input_name": self.input_name, "output_name": self.output_name, "session_providers": self.session.get_providers(), } ) return info class SimpleG2P: """Simple Grapheme-to-Phoneme converter for reference text""" def __init__(self): try: self.cmu_dict = cmudict.dict() except: self.cmu_dict = {} print("Warning: CMU dictionary not available") def text_to_phonemes(self, text: str) -> List[Dict]: """Convert text to phoneme sequence""" words = self._clean_text(text).split() phoneme_sequence = [] for word in words: word_phonemes = self._get_word_phonemes(word) phoneme_sequence.append( { "word": word, "phonemes": word_phonemes, "ipa": self._get_ipa(word), "phoneme_string": " ".join(word_phonemes), } ) return phoneme_sequence def get_reference_phoneme_string(self, text: str) -> str: """Get reference phoneme string for comparison""" phoneme_sequence = self.text_to_phonemes(text) all_phonemes = [] for word_data in phoneme_sequence: all_phonemes.extend(word_data["phonemes"]) return " ".join(all_phonemes) def _clean_text(self, text: str) -> str: """Clean text for processing""" text = re.sub(r"[^\w\s\']", " ", text) text = re.sub(r"\s+", " ", text) return text.lower().strip() def _get_word_phonemes(self, word: str) -> List[str]: """Get phonemes for a word""" word_lower = word.lower() if word_lower in self.cmu_dict: # Remove stress markers and convert to Wav2Vec2 phoneme format phonemes = self.cmu_dict[word_lower][0] clean_phonemes = [re.sub(r"[0-9]", "", p) for p in phonemes] return self._convert_to_wav2vec_format(clean_phonemes) else: return self._estimate_phonemes(word) def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]: """Convert CMU phonemes to Wav2Vec2 format""" # Mapping from CMU to Wav2Vec2/eSpeak phonemes cmu_to_espeak = { "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ", "EH": "ɛ", "ER": "ɝ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ", "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "tʃ", "D": "d", "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k", "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "r", "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ", } converted = [] for phoneme in cmu_phonemes: converted_phoneme = cmu_to_espeak.get(phoneme, phoneme.lower()) converted.append(converted_phoneme) return converted def _get_ipa(self, word: str) -> str: """Get IPA transcription""" try: return ipa.convert(word) except: return f"/{word}/" def _estimate_phonemes(self, word: str) -> List[str]: """Estimate phonemes for unknown words""" # Basic phoneme estimation with eSpeak-style output phoneme_map = { "ch": ["tʃ"], "sh": ["ʃ"], "th": ["θ"], "ph": ["f"], "ck": ["k"], "ng": ["ŋ"], "qu": ["k", "w"], "a": ["æ"], "e": ["ɛ"], "i": ["ɪ"], "o": ["ʌ"], "u": ["ʌ"], "b": ["b"], "c": ["k"], "d": ["d"], "f": 
["f"], "g": ["ɡ"], "h": ["h"], "j": ["dʒ"], "k": ["k"], "l": ["l"], "m": ["m"], "n": ["n"], "p": ["p"], "r": ["r"], "s": ["s"], "t": ["t"], "v": ["v"], "w": ["w"], "x": ["k", "s"], "y": ["j"], "z": ["z"], } word = word.lower() phonemes = [] i = 0 while i < len(word): # Check 2-letter combinations first if i <= len(word) - 2: two_char = word[i : i + 2] if two_char in phoneme_map: phonemes.extend(phoneme_map[two_char]) i += 2 continue # Single character char = word[i] if char in phoneme_map: phonemes.extend(phoneme_map[char]) i += 1 return phonemes class PhonemeComparator: """Compare reference and learner phoneme sequences""" def __init__(self): # Vietnamese speakers' common phoneme substitutions self.substitution_patterns = { "θ": ["f", "s", "t"], # TH → F, S, T "ð": ["d", "z", "v"], # DH → D, Z, V "v": ["w", "f"], # V → W, F "r": ["l"], # R → L "l": ["r"], # L → R "z": ["s"], # Z → S "ʒ": ["ʃ", "z"], # ZH → SH, Z "ŋ": ["n"], # NG → N } # Difficulty levels for Vietnamese speakers self.difficulty_map = { "θ": 0.9, # th (think) "ð": 0.9, # th (this) "v": 0.8, # v "z": 0.8, # z "ʒ": 0.9, # zh (measure) "r": 0.7, # r "l": 0.6, # l "w": 0.5, # w "f": 0.4, # f "s": 0.3, # s "ʃ": 0.5, # sh "tʃ": 0.4, # ch "dʒ": 0.5, # j "ŋ": 0.3, # ng } def compare_phoneme_sequences( self, reference_phonemes: str, learner_phonemes: str ) -> List[Dict]: """Compare reference and learner phoneme sequences""" # Split phoneme strings ref_phones = reference_phonemes.split() learner_phones = learner_phonemes.split() print(f"Reference phonemes: {ref_phones}") print(f"Learner phonemes: {learner_phones}") # Simple alignment comparison comparisons = [] max_len = max(len(ref_phones), len(learner_phones)) for i in range(max_len): ref_phoneme = ref_phones[i] if i < len(ref_phones) else "" learner_phoneme = learner_phones[i] if i < len(learner_phones) else "" if ref_phoneme and learner_phoneme: # Both present - check accuracy if ref_phoneme == learner_phoneme: status = "correct" score = 1.0 elif self._is_acceptable_substitution(ref_phoneme, learner_phoneme): status = "acceptable" score = 0.7 else: status = "wrong" score = 0.2 elif ref_phoneme and not learner_phoneme: # Missing phoneme status = "missing" score = 0.0 elif learner_phoneme and not ref_phoneme: # Extra phoneme status = "extra" score = 0.0 else: continue comparison = { "position": i, "reference_phoneme": ref_phoneme, "learner_phoneme": learner_phoneme, "status": status, "score": score, "difficulty": self.difficulty_map.get(ref_phoneme, 0.3), } comparisons.append(comparison) return comparisons def _is_acceptable_substitution(self, reference: str, learner: str) -> bool: """Check if learner phoneme is acceptable substitution for Vietnamese speakers""" acceptable = self.substitution_patterns.get(reference, []) return learner in acceptable # ============================================================================= # WORD ANALYZER # ============================================================================= class WordAnalyzer: """Analyze word-level pronunciation accuracy using character-based ASR""" def __init__(self): self.g2p = SimpleG2P() self.comparator = PhonemeComparator() def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict: """Analyze word-level pronunciation using phoneme representation from character ASR""" # Get reference phonemes by word reference_words = self.g2p.text_to_phonemes(reference_text) # Get overall phoneme comparison reference_phoneme_string = self.g2p.get_reference_phoneme_string(reference_text) phoneme_comparisons = 
        phoneme_comparisons = self.comparator.compare_phoneme_sequences(
            reference_phoneme_string, learner_phonemes
        )

        # Map phonemes back to words
        word_highlights = self._create_word_highlights(
            reference_words, phoneme_comparisons
        )

        # Identify wrong words
        wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)

        return {
            "word_highlights": word_highlights,
            "phoneme_differences": phoneme_comparisons,
            "wrong_words": wrong_words,
        }

    def _create_word_highlights(
        self, reference_words: List[Dict], phoneme_comparisons: List[Dict]
    ) -> List[Dict]:
        """Create word highlighting data"""
        word_highlights = []
        phoneme_index = 0

        for word_data in reference_words:
            word = word_data["word"]
            word_phonemes = word_data["phonemes"]
            num_phonemes = len(word_phonemes)

            # Get phoneme scores for this word
            word_phoneme_scores = []
            for j in range(num_phonemes):
                if phoneme_index + j < len(phoneme_comparisons):
                    comparison = phoneme_comparisons[phoneme_index + j]
                    word_phoneme_scores.append(comparison["score"])

            # Calculate word score
            word_score = np.mean(word_phoneme_scores) if word_phoneme_scores else 0.0

            # Create word highlight
            highlight = {
                "word": word,
                "score": float(word_score),
                "status": self._get_word_status(word_score),
                "color": self._get_word_color(word_score),
                "phonemes": word_phonemes,
                "ipa": word_data["ipa"],
                "phoneme_scores": word_phoneme_scores,
                "phoneme_start_index": phoneme_index,
                "phoneme_end_index": phoneme_index + num_phonemes - 1,
            }

            word_highlights.append(highlight)
            phoneme_index += num_phonemes

        return word_highlights

    def _identify_wrong_words(
        self, word_highlights: List[Dict], phoneme_comparisons: List[Dict]
    ) -> List[Dict]:
        """Identify words that were pronounced incorrectly"""
        wrong_words = []

        for word_highlight in word_highlights:
            if word_highlight["score"] < 0.6:  # Threshold for wrong pronunciation
                # Find specific phoneme errors for this word
                start_idx = word_highlight["phoneme_start_index"]
                end_idx = word_highlight["phoneme_end_index"]

                wrong_phonemes = []
                missing_phonemes = []

                for i in range(start_idx, min(end_idx + 1, len(phoneme_comparisons))):
                    comparison = phoneme_comparisons[i]
                    if comparison["status"] == "wrong":
                        wrong_phonemes.append(
                            {
                                "expected": comparison["reference_phoneme"],
                                "actual": comparison["learner_phoneme"],
                                "difficulty": comparison["difficulty"],
                            }
                        )
                    elif comparison["status"] == "missing":
                        missing_phonemes.append(
                            {
                                "phoneme": comparison["reference_phoneme"],
                                "difficulty": comparison["difficulty"],
                            }
                        )

                wrong_word = {
                    "word": word_highlight["word"],
                    "score": word_highlight["score"],
                    "expected_phonemes": word_highlight["phonemes"],
                    "ipa": word_highlight["ipa"],
                    "wrong_phonemes": wrong_phonemes,
                    "missing_phonemes": missing_phonemes,
                    "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes),
                }

                wrong_words.append(wrong_word)

        return wrong_words

    def _get_word_status(self, score: float) -> str:
        """Get word status from score"""
        if score >= 0.8:
            return "excellent"
        elif score >= 0.6:
            return "good"
        elif score >= 0.4:
            return "needs_practice"
        else:
            return "poor"

    def _get_word_color(self, score: float) -> str:
        """Get color for word highlighting"""
        if score >= 0.8:
            return "#22c55e"  # Green
        elif score >= 0.6:
            return "#84cc16"  # Light green
        elif score >= 0.4:
            return "#eab308"  # Yellow
        else:
            return "#ef4444"  # Red

    def _get_vietnamese_tips(
        self, wrong_phonemes: List[Dict], missing_phonemes: List[Dict]
    ) -> List[str]:
        """Get Vietnamese-specific pronunciation tips"""
        tips = []

        # Tips for specific Vietnamese pronunciation challenges
        vietnamese_tips = {
            "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
            "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
            "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
            "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
            "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
            "z": "Giống âm 's' nhưng có rung dây thanh âm",
            "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
            "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'",
        }
(think, three)", "ð": "Giống θ nhưng rung dây thanh âm (this, that)", "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt", "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi", "l": "Đầu lưỡi chạm vào vòm miệng sau răng", "z": "Giống âm 's' nhưng có rung dây thanh âm", "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm", "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'", } # Add tips for wrong phonemes for wrong in wrong_phonemes: expected = wrong["expected"] actual = wrong["actual"] if expected in vietnamese_tips: tips.append(f"Âm '{expected}': {vietnamese_tips[expected]}") else: tips.append(f"Luyện âm '{expected}' thay vì '{actual}'") # Add tips for missing phonemes for missing in missing_phonemes: phoneme = missing["phoneme"] if phoneme in vietnamese_tips: tips.append(f"Thiếu âm '{phoneme}': {vietnamese_tips[phoneme]}") return tips class SimpleFeedbackGenerator: """Generate simple, actionable feedback in Vietnamese""" def generate_feedback( self, overall_score: float, wrong_words: List[Dict], phoneme_comparisons: List[Dict], ) -> List[str]: """Generate Vietnamese feedback""" feedback = [] # Overall feedback in Vietnamese if overall_score >= 0.8: feedback.append("Phát âm rất tốt! Bạn đã làm xuất sắc.") elif overall_score >= 0.6: feedback.append("Phát âm khá tốt, còn một vài điểm cần cải thiện.") elif overall_score >= 0.4: feedback.append( "Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ." ) else: feedback.append("Hãy luyện tập chậm và rõ ràng hơn.") # Wrong words feedback if wrong_words: if len(wrong_words) <= 3: word_names = [w["word"] for w in wrong_words] feedback.append(f"Các từ cần luyện tập: {', '.join(word_names)}") else: feedback.append( f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một." 
        # Most problematic phonemes
        problem_phonemes = defaultdict(int)
        for comparison in phoneme_comparisons:
            if comparison["status"] in ["wrong", "missing"]:
                phoneme = comparison["reference_phoneme"]
                problem_phonemes[phoneme] += 1

        if problem_phonemes:
            most_difficult = sorted(
                problem_phonemes.items(), key=lambda x: x[1], reverse=True
            )
            top_problem = most_difficult[0][0]

            phoneme_tips = {
                "θ": "Lưỡi giữa răng, thổi nhẹ",
                "ð": "Lưỡi giữa răng, rung dây thanh",
                "v": "Môi dưới chạm răng trên",
                "r": "Cuộn lưỡi, không chạm vòm miệng",
                "l": "Lưỡi chạm vòm miệng",
                "z": "Như 's' nhưng rung dây thanh",
            }

            if top_problem in phoneme_tips:
                feedback.append(
                    f"Âm khó nhất '{top_problem}': {phoneme_tips[top_problem]}"
                )

        return feedback


class SimplePronunciationAssessor:
    """Main pronunciation assessor supporting both normal (Whisper) and advanced (Wav2Vec2) modes"""

    def __init__(self):
        print("Initializing Simple Pronunciation Assessor...")
        self.wav2vec2_asr = Wav2Vec2CharacterASR()  # Advanced mode
        self.word_analyzer = WordAnalyzer()
        self.feedback_generator = SimpleFeedbackGenerator()
        print("Initialization completed")

    def assess_pronunciation(
        self, audio_path: str, reference_text: str, mode: str = "normal"
    ) -> Dict:
        """
        Main assessment function with mode selection

        Args:
            audio_path: Path to audio file
            reference_text: Reference text to compare
            mode: 'normal' (Whisper) or 'advanced' (Wav2Vec2); only the Wav2Vec2
                path is implemented in this module

        Output: Word highlights + Phoneme differences + Wrong words
        """
        print(f"Starting pronunciation assessment in {mode} mode...")

        # Step 1: Choose ASR model based on mode
        if mode == "advanced":
            print("Step 1: Using Wav2Vec2 character transcription...")
        else:
            # The Whisper-based "normal" mode is not wired up here; fall back to
            # the same Wav2Vec2 character ASR so asr_result is always defined.
            print("Step 1: Whisper mode unavailable, using Wav2Vec2 character transcription...")
        asr_result = self.wav2vec2_asr.transcribe_to_characters(audio_path)
        model_info = f"Wav2Vec2-Character ({self.wav2vec2_asr.model_name})"

        character_transcript = asr_result["character_transcript"]
        phoneme_representation = asr_result["phoneme_representation"]

        print(f"Character transcript: {character_transcript}")
        print(f"Phoneme representation: {phoneme_representation}")

        # Step 2: Word analysis using phoneme representation
        print("Step 2: Analyzing words...")
        analysis_result = self.word_analyzer.analyze_words(
            reference_text, phoneme_representation
        )

        # Step 3: Calculate overall score
        phoneme_comparisons = analysis_result["phoneme_differences"]
        overall_score = self._calculate_overall_score(phoneme_comparisons)

        # Step 4: Generate feedback
        print("Step 4: Generating feedback...")
        feedback = self.feedback_generator.generate_feedback(
            overall_score, analysis_result["wrong_words"], phoneme_comparisons
        )

        result = {
            "transcript": character_transcript,  # What the user actually said
            "transcript_phonemes": phoneme_representation,
            "user_phonemes": phoneme_representation,  # Alias for UI clarity
            "character_transcript": character_transcript,
            "overall_score": overall_score,
            "word_highlights": analysis_result["word_highlights"],
            "phoneme_differences": phoneme_comparisons,
            "wrong_words": analysis_result["wrong_words"],
            "feedback": feedback,
            "processing_info": {
                "model_used": model_info,
                "mode": mode,
                "character_based": mode == "advanced",
                "language_model_correction": mode == "normal",
                "raw_output": mode == "advanced",
            },
        }

        print("Assessment completed successfully")
        return result

    def _calculate_overall_score(self, phoneme_comparisons: List[Dict]) -> float:
        """Calculate overall pronunciation score"""
        if not phoneme_comparisons:
            return 0.0

        total_score = sum(comparison["score"] for comparison in phoneme_comparisons)
        return total_score / len(phoneme_comparisons)
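

# Minimal usage sketch (not part of the original pipeline): it shows how the
# assessor is meant to be driven end to end. The audio path below is a
# hypothetical example file; any 16 kHz mono recording of the reference text
# would do. Run the module directly to see the word highlights and feedback.
if __name__ == "__main__":
    assessor = SimplePronunciationAssessor()
    demo_result = assessor.assess_pronunciation(
        audio_path="samples/hello_world.wav",  # hypothetical example recording
        reference_text="hello world",
        mode="advanced",
    )

    print("Overall score:", demo_result["overall_score"])
    for highlight in demo_result["word_highlights"]:
        print(highlight["word"], highlight["score"], highlight["status"])
    for line in demo_result["feedback"]:
        print(line)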