Run_code_api / raw.py
ABAO77's picture
Implement enhanced pronunciation assessment system with Wav2Vec2 support
aa2c910
raw
history blame
27.4 kB
from typing import List, Dict
import numpy as np
import librosa
import nltk
import eng_to_ipa as ipa
import re
from collections import defaultdict
from loguru import logger
import time
from src.AI_Models.wave2vec_inference import (
Wave2Vec2Inference,
Wave2Vec2ONNXInference,
export_to_onnx,
)
# Fetch the CMU Pronouncing Dictionary corpus at import time. This is
# best-effort: SimpleG2P falls back to an empty dictionary (and therefore
# to rule-based phoneme estimation) when the corpus cannot be used.
try:
    nltk.download("cmudict", quiet=True)
    from nltk.corpus import cmudict
except Exception:
    # ``except Exception`` instead of a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate during start-up.
    # Bind the name so later references fail with a catchable
    # AttributeError instead of a NameError on an undefined global.
    cmudict = None
    print("Warning: NLTK data not available")
class Wav2Vec2CharacterASR:
    """Character-level Wav2Vec2 ASR using either Transformers or ONNX inference.

    The inference backend is stored in ``self.model``: a
    ``Wave2Vec2Inference`` when ``onnx`` is false, otherwise a
    ``Wave2Vec2ONNXInference`` pointed at a local ``.onnx`` file.
    """

    # Base name (without ".onnx") of the exported model file.
    # NOTE(review): this is fixed to the default model's name, so a custom
    # ``model_name`` combined with ``onnx=True`` still resolves to this
    # file -- confirm against the naming used by ``export_to_onnx``.
    _ONNX_BASENAME = "wav2vec2-large-960h-lv60-self"

    # Last-resort single-letter mapping used when G2P conversion of a word
    # raises. Letters absent from the table are dropped.
    _LETTER_TO_PHONEME = {
        "a": "æ",
        "b": "b",
        "c": "k",
        "d": "d",
        "e": "ɛ",
        "f": "f",
        "g": "ɡ",
        "h": "h",
        "i": "ɪ",
        "j": "dʒ",
        "k": "k",
        "l": "l",
        "m": "m",
        "n": "n",
        "o": "ʌ",
        "p": "p",
        "q": "k",
        "r": "r",
        "s": "s",
        "t": "t",
        "u": "ʌ",
        "v": "v",
        "w": "w",
        "x": "ks",
        "y": "j",
        "z": "z",
    }

    def __init__(
        self,
        model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
        onnx: bool = False,
        quantized: bool = False,
    ):
        """Create the inference backend.

        Args:
            model_name: HuggingFace model name passed to the backend.
            onnx: If True, use the ONNX backend; otherwise use Transformers.
            quantized: Only used when ``onnx`` is True. Selects the
                ``.quant.onnx`` file and is forwarded to ``export_to_onnx``
                as ``quantize``.
        """
        self.use_onnx = onnx
        self.sample_rate = 16000
        self.model_name = model_name
        # Path of the ONNX file in use; stays None for the Transformers backend.
        self.onnx_model_path = None
        # SimpleG2P instance, created on first use and then reused.
        self._g2p = None

        if onnx:
            import os

            self.onnx_model_path = (
                self._ONNX_BASENAME + (".quant" if quantized else "") + ".onnx"
            )
            # Export the model first if the ONNX file is not on disk yet.
            if not os.path.exists(self.onnx_model_path):
                export_to_onnx(model_name, quantize=quantized)
            self.model = Wave2Vec2ONNXInference(model_name, self.onnx_model_path)
        else:
            self.model = Wave2Vec2Inference(model_name)

    def transcribe_to_characters(self, audio_path: str) -> Dict:
        """Transcribe an audio file and derive a phoneme-like representation.

        Args:
            audio_path: Path handed to the backend's ``file_to_text``.

        Returns:
            A dict with ``character_transcript`` (cleaned, lower-cased text)
            and ``phoneme_representation`` (space-separated phonemes). If
            anything raises, the structure from ``_empty_result`` is
            returned instead and the error is logged.
        """
        try:
            start_time = time.time()
            character_transcript = self.model.file_to_text(audio_path)
            character_transcript = self._clean_character_transcript(
                character_transcript
            )
            phoneme_like_transcript = self._characters_to_phoneme_representation(
                character_transcript
            )
            logger.info(f"Transcription time: {time.time() - start_time:.2f}s")
            return {
                "character_transcript": character_transcript,
                "phoneme_representation": phoneme_like_transcript,
            }
        except Exception as e:
            # Best-effort API: callers always get a dict back. The message is
            # backend-neutral because this path also serves ONNX inference.
            logger.error(f"Wav2Vec2 transcription error: {e}")
            return self._empty_result()

    def _calculate_confidence_scores(self, logits: np.ndarray) -> List[float]:
        """Return the per-step maximum softmax probability for ``logits``.

        The softmax is taken over the last axis and element ``[0]`` of the
        resulting max-probability array is returned as a list.
        NOTE(review): this presumably expects a leading batch axis of size
        one, i.e. (1, time, vocab) -- confirm against the caller.
        """
        # Subtract the row maximum before exponentiating for numerical stability.
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        softmax_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
        max_probs = np.max(softmax_probs, axis=-1)[0]
        return max_probs.tolist()

    def _clean_character_transcript(self, transcript: str) -> str:
        """Collapse whitespace runs, strip the ends and lower-case the text."""
        logger.info(f"Raw transcript before cleaning: {transcript}")
        cleaned = re.sub(r"\s+", " ", transcript)
        return cleaned.strip().lower()

    def _get_g2p(self):
        """Return the shared SimpleG2P instance, creating it on first use.

        Building it once avoids reloading the pronunciation dictionary for
        every transcription.
        """
        g2p = getattr(self, "_g2p", None)
        if g2p is None:
            g2p = self._g2p = SimpleG2P()
        return g2p

    def _characters_to_phoneme_representation(self, text: str) -> str:
        """Convert a character transcript to a space-separated phoneme string.

        Each whitespace-separated word is converted with SimpleG2P; if that
        raises for a word, the single-letter fallback table is used instead.
        Empty input yields an empty string.
        """
        if not text:
            return ""

        g2p = self._get_g2p()
        phonemes: List[str] = []
        for word in text.split():
            try:
                phonemes.extend(g2p.text_to_phonemes(word)[0]["phonemes"])
            except Exception:
                # E.g. the word cleaned down to nothing, so there is no entry.
                phonemes.extend(self._simple_letter_to_phoneme(word))
        return " ".join(phonemes)

    def _simple_letter_to_phoneme(self, word: str) -> List[str]:
        """Map each known letter of ``word`` to one phoneme; skip the rest."""
        table = self._LETTER_TO_PHONEME
        return [table[letter] for letter in word.lower() if letter in table]

    def _empty_result(self) -> Dict:
        """Return the result structure used when transcription fails."""
        return {
            "character_transcript": "",
            "phoneme_representation": "",
            "raw_predicted_ids": [],
            "confidence_scores": [],
        }

    def _lookup_attr(self, name: str):
        """Return attribute ``name`` from this object or ``self.model``, else None."""
        for owner in (self, getattr(self, "model", None)):
            if owner is not None and hasattr(owner, name):
                return getattr(owner, name)
        return None

    def get_model_info(self) -> Dict:
        """Describe the loaded model.

        Always reports ``model_name``, ``sample_rate`` and
        ``inference_method``. For the ONNX backend it additionally reports
        ``onnx_model_path``, ``input_name``, ``output_name`` and
        ``session_providers``; any of these that cannot be found is None.
        """
        info = {
            "model_name": self.model_name,
            "sample_rate": self.sample_rate,
            "inference_method": "ONNX" if self.use_onnx else "Transformers",
        }
        if self.use_onnx:
            # These attributes were previously read straight from ``self``,
            # where they are never set, so this method always raised
            # AttributeError for ONNX. Look them up defensively instead.
            # NOTE(review): the attribute names on the ONNX wrapper are not
            # visible from this file -- confirm they match.
            session = self._lookup_attr("session")
            get_providers = getattr(session, "get_providers", None)
            info.update(
                {
                    "onnx_model_path": self._lookup_attr("onnx_model_path"),
                    "input_name": self._lookup_attr("input_name"),
                    "output_name": self._lookup_attr("output_name"),
                    "session_providers": (
                        get_providers() if callable(get_providers) else None
                    ),
                }
            )
        return info
class SimpleG2P:
    """Simple grapheme-to-phoneme converter for reference text.

    Words found in the CMU dictionary are converted through a fixed
    CMU-to-IPA table; all other words are estimated from their spelling.
    """

    # CMU (ARPAbet) symbol -> IPA symbol. Stress digits are stripped before
    # lookup; symbols missing from the table are passed through lower-cased.
    _CMU_TO_IPA = {
        "AA": "ɑ",
        "AE": "æ",
        "AH": "ʌ",
        "AO": "ɔ",
        "AW": "aʊ",
        "AY": "aɪ",
        "EH": "ɛ",
        "ER": "ɝ",
        "EY": "eɪ",
        "IH": "ɪ",
        "IY": "i",
        "OW": "oʊ",
        "OY": "ɔɪ",
        "UH": "ʊ",
        "UW": "u",
        "B": "b",
        "CH": "tʃ",
        "D": "d",
        "DH": "ð",
        "F": "f",
        "G": "ɡ",
        "HH": "h",
        "JH": "dʒ",
        "K": "k",
        "L": "l",
        "M": "m",
        "N": "n",
        "NG": "ŋ",
        "P": "p",
        "R": "r",
        "S": "s",
        "SH": "ʃ",
        "T": "t",
        "TH": "θ",
        "V": "v",
        "W": "w",
        "Y": "j",
        "Z": "z",
        "ZH": "ʒ",
    }

    # Spelling -> phonemes for words missing from the dictionary. Two-letter
    # keys are tried before single letters.
    _GRAPHEME_TO_PHONEMES = {
        "ch": ["tʃ"],
        "sh": ["ʃ"],
        "th": ["θ"],
        "ph": ["f"],
        "ck": ["k"],
        "ng": ["ŋ"],
        "qu": ["k", "w"],
        "a": ["æ"],
        "e": ["ɛ"],
        "i": ["ɪ"],
        "o": ["ʌ"],
        "u": ["ʌ"],
        "b": ["b"],
        "c": ["k"],
        "d": ["d"],
        "f": ["f"],
        "g": ["ɡ"],
        "h": ["h"],
        "j": ["dʒ"],
        "k": ["k"],
        "l": ["l"],
        "m": ["m"],
        "n": ["n"],
        "p": ["p"],
        # A "q" not followed by "u" used to be dropped silently.
        "q": ["k"],
        "r": ["r"],
        "s": ["s"],
        "t": ["t"],
        "v": ["v"],
        "w": ["w"],
        "x": ["k", "s"],
        "y": ["j"],
        "z": ["z"],
    }

    def __init__(self):
        """Load the CMU dictionary, falling back to an empty one on failure."""
        try:
            self.cmu_dict = cmudict.dict()
        except Exception:
            # Covers a missing corpus as well as ``cmudict`` being unavailable.
            self.cmu_dict = {}
            print("Warning: CMU dictionary not available")

    def text_to_phonemes(self, text: str) -> List[Dict]:
        """Convert text to one entry per word.

        Returns:
            A list of dicts with keys ``word``, ``phonemes`` (list of
            strings), ``ipa`` and ``phoneme_string`` (phonemes joined by
            spaces), in word order. Empty for text with no words.
        """
        entries = []
        for word in self._clean_text(text).split():
            phonemes = self._get_word_phonemes(word)
            entries.append(
                {
                    "word": word,
                    "phonemes": phonemes,
                    "ipa": self._get_ipa(word),
                    "phoneme_string": " ".join(phonemes),
                }
            )
        return entries

    def get_reference_phoneme_string(self, text: str) -> str:
        """Return all phonemes of ``text`` as a single space-separated string."""
        return " ".join(
            phoneme
            for entry in self.text_to_phonemes(text)
            for phoneme in entry["phonemes"]
        )

    def _clean_text(self, text: str) -> str:
        """Normalize text for word splitting.

        Characters other than word characters, whitespace and apostrophes
        become spaces; whitespace runs are collapsed; the result is
        lower-cased and stripped.
        """
        text = re.sub(r"[^\w\s\']", " ", text)
        text = re.sub(r"\s+", " ", text)
        return text.lower().strip()

    def _get_word_phonemes(self, word: str) -> List[str]:
        """Return phonemes for ``word`` from the dictionary or by estimation.

        Only the first listed pronunciation of a dictionary word is used.
        """
        pronunciations = self.cmu_dict.get(word.lower())
        if pronunciations is None:
            return self._estimate_phonemes(word)
        # Strip the stress digits (e.g. "AH0" -> "AH") before converting.
        stressless = [re.sub(r"[0-9]", "", symbol) for symbol in pronunciations[0]]
        return self._convert_to_wav2vec_format(stressless)

    def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]:
        """Map CMU symbols to IPA; unknown symbols are returned lower-cased."""
        table = self._CMU_TO_IPA
        return [table.get(symbol, symbol.lower()) for symbol in cmu_phonemes]

    def _get_ipa(self, word: str) -> str:
        """Return the IPA transcription of ``word``, or ``/word/`` on failure."""
        try:
            return ipa.convert(word)
        except Exception:
            return f"/{word}/"

    def _estimate_phonemes(self, word: str) -> List[str]:
        """Estimate phonemes from spelling for a word not in the dictionary.

        Scans left to right, preferring a two-letter match over a
        single-letter one. Characters with no mapping are skipped.
        """
        table = self._GRAPHEME_TO_PHONEMES
        word = word.lower()
        phonemes: List[str] = []
        i = 0
        while i < len(word):
            digraph = word[i : i + 2]
            if len(digraph) == 2 and digraph in table:
                phonemes.extend(table[digraph])
                i += 2
                continue
            phonemes.extend(table.get(word[i], ()))
            i += 1
        return phonemes
class PhonemeComparator:
    """Score a learner's phoneme sequence against a reference, position by position."""

    def __init__(self):
        # Substitutions commonly made by Vietnamese speakers, keyed by the
        # reference phoneme; these count as "acceptable" rather than "wrong".
        self.substitution_patterns = {
            "θ": ["f", "s", "t"],
            "ð": ["d", "z", "v"],
            "v": ["w", "f"],
            "r": ["l"],
            "l": ["r"],
            "z": ["s"],
            "ʒ": ["ʃ", "z"],
            "ŋ": ["n"],
        }

        # Difficulty per reference phoneme for Vietnamese speakers; phonemes
        # not listed get 0.3 at lookup time.
        self.difficulty_map = {
            "θ": 0.9,
            "ð": 0.9,
            "v": 0.8,
            "z": 0.8,
            "ʒ": 0.9,
            "r": 0.7,
            "l": 0.6,
            "w": 0.5,
            "f": 0.4,
            "s": 0.3,
            "ʃ": 0.5,
            "tʃ": 0.4,
            "dʒ": 0.5,
            "ŋ": 0.3,
        }

    def compare_phoneme_sequences(
        self, reference_phonemes: str, learner_phonemes: str
    ) -> List[Dict]:
        """Compare two space-separated phoneme strings index by index.

        No alignment is attempted: the phoneme at position ``i`` of the
        reference is paired with the phoneme at position ``i`` of the
        learner, and the shorter sequence is padded with empty strings.

        Returns:
            One dict per position with ``position``, ``reference_phoneme``,
            ``learner_phoneme``, ``status`` ("correct", "acceptable",
            "wrong", "missing" or "extra"), ``score`` and ``difficulty``.
        """
        expected = reference_phonemes.split()
        spoken = learner_phonemes.split()
        print(f"Reference phonemes: {expected}")
        print(f"Learner phonemes: {spoken}")

        # Pad both sides to a common length so they can be walked in lockstep.
        total = max(len(expected), len(spoken))
        expected_padded = expected + [""] * (total - len(expected))
        spoken_padded = spoken + [""] * (total - len(spoken))

        results = []
        for position, (ref, heard) in enumerate(zip(expected_padded, spoken_padded)):
            verdict = self._classify_pair(ref, heard)
            if verdict is None:
                continue
            status, score = verdict
            results.append(
                {
                    "position": position,
                    "reference_phoneme": ref,
                    "learner_phoneme": heard,
                    "status": status,
                    "score": score,
                    "difficulty": self.difficulty_map.get(ref, 0.3),
                }
            )
        return results

    def _classify_pair(self, ref: str, heard: str):
        """Return ``(status, score)`` for one position, or None if both are empty."""
        if not ref and not heard:
            return None
        if not heard:
            return "missing", 0.0
        if not ref:
            return "extra", 0.0
        if ref == heard:
            return "correct", 1.0
        if self._is_acceptable_substitution(ref, heard):
            return "acceptable", 0.7
        return "wrong", 0.2

    def _is_acceptable_substitution(self, reference: str, learner: str) -> bool:
        """Tell whether ``learner`` is a listed substitution for ``reference``."""
        return learner in self.substitution_patterns.get(reference, [])
# =============================================================================
# WORD ANALYZER
# =============================================================================
class WordAnalyzer:
    """Analyze word-level pronunciation accuracy using character-based ASR output."""

    # Words scoring below this are reported as mispronounced.
    _WRONG_WORD_THRESHOLD = 0.6

    # (minimum score, status, highlight colour), checked from best to worst.
    _SCORE_BANDS = (
        (0.8, "excellent", "#22c55e"),  # green
        (0.6, "good", "#84cc16"),  # light green
        (0.4, "needs_practice", "#eab308"),  # yellow
    )
    # Status and colour for scores below every band.
    _LOWEST_BAND = ("poor", "#ef4444")  # red

    # Articulation tips (user-facing, in Vietnamese) keyed by reference phoneme.
    _VIETNAMESE_TIPS = {
        "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
        "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
        "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
        "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
        "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
        "z": "Giống âm 's' nhưng có rung dây thanh âm",
        "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
        "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'",
    }

    def __init__(self):
        self.g2p = SimpleG2P()
        self.comparator = PhonemeComparator()

    def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict:
        """Analyze the learner's phonemes against the reference text.

        Args:
            reference_text: Text the learner was supposed to read.
            learner_phonemes: Space-separated phonemes derived from the ASR
                transcript.

        Returns:
            A dict with ``word_highlights``, ``phoneme_differences`` and
            ``wrong_words``.
        """
        reference_words = self.g2p.text_to_phonemes(reference_text)
        # Build the flat reference string from the per-word result instead of
        # running G2P (dictionary and IPA lookups) a second time.
        reference_phoneme_string = " ".join(
            phoneme for entry in reference_words for phoneme in entry["phonemes"]
        )
        phoneme_comparisons = self.comparator.compare_phoneme_sequences(
            reference_phoneme_string, learner_phonemes
        )
        word_highlights = self._create_word_highlights(
            reference_words, phoneme_comparisons
        )
        wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)
        return {
            "word_highlights": word_highlights,
            "phoneme_differences": phoneme_comparisons,
            "wrong_words": wrong_words,
        }

    def _create_word_highlights(
        self, reference_words: List[Dict], phoneme_comparisons: List[Dict]
    ) -> List[Dict]:
        """Build per-word highlight records.

        Each word takes the next ``len(word["phonemes"])`` comparisons, in
        order. A word's score is the mean of its phoneme scores, or 0.0 when
        no comparison falls inside its span.
        """
        highlights = []
        start = 0
        for entry in reference_words:
            phonemes = entry["phonemes"]
            count = len(phonemes)
            # Slicing clips at the end of the list, so a word running past the
            # available comparisons simply receives fewer scores.
            scores = [c["score"] for c in phoneme_comparisons[start : start + count]]
            word_score = float(np.mean(scores)) if scores else 0.0
            highlights.append(
                {
                    "word": entry["word"],
                    "score": word_score,
                    "status": self._get_word_status(word_score),
                    "color": self._get_word_color(word_score),
                    "phonemes": phonemes,
                    "ipa": entry["ipa"],
                    "phoneme_scores": scores,
                    "phoneme_start_index": start,
                    "phoneme_end_index": start + count - 1,
                }
            )
            start += count
        return highlights

    def _identify_wrong_words(
        self, word_highlights: List[Dict], phoneme_comparisons: List[Dict]
    ) -> List[Dict]:
        """Return details for every word scoring below the wrong-word threshold.

        For each such word the "wrong" and "missing" comparisons within its
        phoneme span are collected, together with pronunciation tips.
        """
        wrong_words = []
        for highlight in word_highlights:
            if not highlight["score"] < self._WRONG_WORD_THRESHOLD:
                continue

            first = highlight["phoneme_start_index"]
            last = highlight["phoneme_end_index"]
            wrong_phonemes = []
            missing_phonemes = []
            for comparison in phoneme_comparisons[first : last + 1]:
                if comparison["status"] == "wrong":
                    wrong_phonemes.append(
                        {
                            "expected": comparison["reference_phoneme"],
                            "actual": comparison["learner_phoneme"],
                            "difficulty": comparison["difficulty"],
                        }
                    )
                elif comparison["status"] == "missing":
                    missing_phonemes.append(
                        {
                            "phoneme": comparison["reference_phoneme"],
                            "difficulty": comparison["difficulty"],
                        }
                    )

            wrong_words.append(
                {
                    "word": highlight["word"],
                    "score": highlight["score"],
                    "expected_phonemes": highlight["phonemes"],
                    "ipa": highlight["ipa"],
                    "wrong_phonemes": wrong_phonemes,
                    "missing_phonemes": missing_phonemes,
                    "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes),
                }
            )
        return wrong_words

    def _band_for(self, score: float):
        """Return the ``(status, colour)`` pair for ``score``."""
        for floor, status, colour in self._SCORE_BANDS:
            if score >= floor:
                return status, colour
        return self._LOWEST_BAND

    def _get_word_status(self, score: float) -> str:
        """Return "excellent", "good", "needs_practice" or "poor" for ``score``."""
        return self._band_for(score)[0]

    def _get_word_color(self, score: float) -> str:
        """Return the hex highlight colour for ``score``."""
        return self._band_for(score)[1]

    def _get_vietnamese_tips(
        self, wrong_phonemes: List[Dict], missing_phonemes: List[Dict]
    ) -> List[str]:
        """Return Vietnamese pronunciation tips for the given phoneme errors.

        Wrong phonemes always produce a tip (a generic one when no specific
        advice exists); missing phonemes produce a tip only when specific
        advice exists. Repeated tips are reported once, in first-seen order.
        """
        advice = self._VIETNAMESE_TIPS
        tips = []
        for wrong in wrong_phonemes:
            expected = wrong["expected"]
            if expected in advice:
                tips.append(f"Âm '{expected}': {advice[expected]}")
            else:
                tips.append(f"Luyện âm '{expected}' thay vì '{wrong['actual']}'")
        for missing in missing_phonemes:
            phoneme = missing["phoneme"]
            if phoneme in advice:
                tips.append(f"Thiếu âm '{phoneme}': {advice[phoneme]}")
        # The same phoneme can fail several times within one word; dict keys
        # keep insertion order, so this removes repeats without reordering.
        return list(dict.fromkeys(tips))
class SimpleFeedbackGenerator:
    """Produce short, actionable feedback messages in Vietnamese."""

    def generate_feedback(
        self,
        overall_score: float,
        wrong_words: List[Dict],
        phoneme_comparisons: List[Dict],
    ) -> List[str]:
        """Assemble the feedback messages for one assessment.

        The result always starts with a message chosen by ``overall_score``.
        It is followed by a note on the mispronounced words (listed by name
        when there are at most three, otherwise only counted). Finally a tip
        is added for the reference phoneme that was most often "wrong" or
        "missing", provided a tip exists for it.
        """
        messages = []

        # Overall verdict: the first band whose floor the score reaches.
        score_bands = (
            (0.8, "Phát âm rất tốt! Bạn đã làm xuất sắc."),
            (0.6, "Phát âm khá tốt, còn một vài điểm cần cải thiện."),
            (0.4, "Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ."),
        )
        for floor, verdict in score_bands:
            if overall_score >= floor:
                messages.append(verdict)
                break
        else:
            messages.append("Hãy luyện tập chậm và rõ ràng hơn.")

        # Words to practise.
        if wrong_words:
            if len(wrong_words) > 3:
                messages.append(
                    f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một."
                )
            else:
                names = ", ".join(entry["word"] for entry in wrong_words)
                messages.append("Các từ cần luyện tập: " + names)

        # Count failures per reference phoneme.
        failure_counts = {}
        for item in phoneme_comparisons:
            if item["status"] == "wrong" or item["status"] == "missing":
                key = item["reference_phoneme"]
                failure_counts[key] = failure_counts.get(key, 0) + 1

        if not failure_counts:
            return messages

        # ``max`` returns the first key with the highest count, so ties go to
        # the phoneme that failed earliest.
        worst = max(failure_counts, key=failure_counts.get)
        short_tips = {
            "θ": "Lưỡi giữa răng, thổi nhẹ",
            "ð": "Lưỡi giữa răng, rung dây thanh",
            "v": "Môi dưới chạm răng trên",
            "r": "Cuộn lưỡi, không chạm vòm miệng",
            "l": "Lưỡi chạm vòm miệng",
            "z": "Như 's' nhưng rung dây thanh",
        }
        tip = short_tips.get(worst)
        if tip is not None:
            messages.append(f"Âm khó nhất '{worst}': {tip}")
        return messages
class SimplePronunciationAssessor:
    """Main pronunciation assessor.

    Transcription is performed by ``Wav2Vec2CharacterASR``; the result is
    analyzed by ``WordAnalyzer`` and summarized by
    ``SimpleFeedbackGenerator``.
    """

    # Mode names accepted by ``assess_pronunciation``.
    _SUPPORTED_MODES = ("normal", "advanced")

    def __init__(self):
        print("Initializing Simple Pronunciation Assessor...")
        self.wav2vec2_asr = Wav2Vec2CharacterASR()  # Advanced mode
        self.word_analyzer = WordAnalyzer()
        self.feedback_generator = SimpleFeedbackGenerator()
        print("Initialization completed")

    def assess_pronunciation(
        self, audio_path: str, reference_text: str, mode: str = "normal"
    ) -> Dict:
        """Assess the pronunciation of an audio recording against a reference text.

        Args:
            audio_path: Path to the audio file.
            reference_text: Text the speaker was expected to read.
            mode: "normal" or "advanced". Only the Wav2Vec2 character ASR is
                available in this class, so "normal" is served by the same
                backend as "advanced" (a warning is printed). Previously
                every mode other than "advanced" failed with an unbound
                local variable.

        Returns:
            A dict with the transcript, phoneme representation, overall
            score, word highlights, phoneme differences, wrong words,
            feedback and a ``processing_info`` section.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {self._SUPPORTED_MODES}"
            )

        print(f"Starting pronunciation assessment in {mode} mode...")

        # Step 1: transcription.
        # NOTE(review): the docstring of the original code mentions a Whisper
        # backend for "normal" mode, but none is wired into this class.
        if mode != "advanced":
            print(
                "Warning: 'normal' mode has no dedicated ASR backend; "
                "using Wav2Vec2 character transcription"
            )
        print("Step 1: Using Wav2Vec2 character transcription...")
        asr_result = self.wav2vec2_asr.transcribe_to_characters(audio_path)
        # Report the model *name*; interpolating the inference object itself
        # only produced its default repr.
        model_info = f"Wav2Vec2-Character ({self.wav2vec2_asr.model_name})"

        character_transcript = asr_result["character_transcript"]
        phoneme_representation = asr_result["phoneme_representation"]
        print(f"Character transcript: {character_transcript}")
        print(f"Phoneme representation: {phoneme_representation}")

        # Step 2: word analysis using the phoneme representation.
        print("Step 2: Analyzing words...")
        analysis_result = self.word_analyzer.analyze_words(
            reference_text, phoneme_representation
        )

        # Overall score (no progress message is printed for this step).
        phoneme_comparisons = analysis_result["phoneme_differences"]
        overall_score = self._calculate_overall_score(phoneme_comparisons)

        # Step 3: feedback.
        print("Step 3: Generating feedback...")
        feedback = self.feedback_generator.generate_feedback(
            overall_score, analysis_result["wrong_words"], phoneme_comparisons
        )

        result = {
            "transcript": character_transcript,  # What the user actually said
            "transcript_phonemes": phoneme_representation,
            "user_phonemes": phoneme_representation,  # Alias for UI clarity
            "character_transcript": character_transcript,
            "overall_score": overall_score,
            "word_highlights": analysis_result["word_highlights"],
            "phoneme_differences": phoneme_comparisons,
            "wrong_words": analysis_result["wrong_words"],
            "feedback": feedback,
            "processing_info": {
                "model_used": model_info,
                "mode": mode,
                # These flags describe what actually ran, which is always the
                # character-based Wav2Vec2 path.
                "character_based": True,
                "language_model_correction": False,
                "raw_output": True,
            },
        }
        print("Assessment completed successfully")
        return result

    def _calculate_overall_score(self, phoneme_comparisons: List[Dict]) -> float:
        """Return the mean ``score`` of the comparisons, or 0.0 when there are none."""
        if not phoneme_comparisons:
            return 0.0
        total_score = sum(comparison["score"] for comparison in phoneme_comparisons)
        return total_score / len(phoneme_comparisons)