ABAO77 committed
Commit df380ff · 1 Parent(s): c6480d4

Add Wav2Vec2 model conversion to ONNX format and ONNX-based inference


- Implemented a Wav2Vec2ONNXConverter class for converting Wav2Vec2 models to ONNX format, covering model loading, conversion, and verification.
- Added a Wav2Vec2ONNXInference class for running inference with the converted ONNX model.
- Included methods for softmax calculation and for transcribing audio files.
- Added utility functions for creating compatible models and for exporting with fallback options across different ONNX opset versions.
- Introduced an optimization function for ONNX models.
- Created a helper function that converts numpy types to native Python types so results serialize cleanly (e.g., to JSON).
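
As a quick orientation, here is a minimal usage sketch based on how `speaking_controller.py` invokes the converter in this commit; the `Wav2Vec2ONNXInference` lines are commented out and hypothetical, since its exact interface is not shown in this diff.

```python
# Usage sketch only: the converter call mirrors speaking_controller.py;
# the Wav2Vec2ONNXInference lines are hypothetical (interface not shown here).
from src.model_convert.wav2vec2onnx import Wav2Vec2ONNXConverter

converter = Wav2Vec2ONNXConverter("facebook/wav2vec2-base-960h")
onnx_path = converter.convert_to_onnx(
    onnx_path="./wav2vec2_asr.onnx",
    input_length=160000,  # 10 seconds of 16 kHz audio
    opset_version=14,
)

# from src.model_convert.wav2vec2onnx import Wav2Vec2ONNXInference
# inference = Wav2Vec2ONNXInference(onnx_path)   # hypothetical constructor
# print(inference.transcribe("sample.wav"))      # hypothetical method name
```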

.gitignore CHANGED
@@ -19,4 +19,5 @@ data_test
19
  **.tiff
20
  **.webp
21
  **.svg
22
- .serena
 
 
19
  **.tiff
20
  **.webp
21
  **.svg
22
+ .serena
23
+ **.onnx
EVALUATION_CAROUSEL_UPDATES.md DELETED
File without changes
HYDRATION_ERROR_FIXES.md DELETED
File without changes
LESSON_PRACTICE_2_UPDATES.md DELETED
@@ -1,95 +0,0 @@
1
- # Lesson Practice 2 Agent Update - Summary of Changes
2
-
3
- ## Goal
4
- Adjust the `lesson_practice_2` agent so that:
5
- **Teaching Agent** becomes the default agent (instead of the Practice Agent)
6
- The learning experience feels natural and engaging
7
- **Responses are short and interactive** - not so long that users get discouraged
8
- Switching between teaching and practice mode is smooth
9
- Users feel comfortable and want to interact more
10
-
11
- ## Main Changes
12
-
13
- ### 1. Default agent (func.py)
14
- **Before**: `state["active_agent"] = "Practice Agent"`
15
- **After**: `state["active_agent"] = "Teaching Agent"`
16
- **Reason**: Start with teaching and guidance before practice
17
-
18
- ### 2. Teaching Agent Prompt (prompt.py)
19
- #### Key improvements:
20
- **Natural teaching philosophy**: Start from the learner's current level and build confidence gradually
21
- **Language flexibility**: Vietnamese when needed, English whenever possible
22
- **Engaging style**: Enthusiastic, patient, encouraging, with light humor
23
- **Concise responses**: 10-20 words maximum, one question, focused on interaction
24
- **Interactive teaching method**: One concept at a time, ask for input right away, avoid over-explaining
25
- **Quick error handling**: Correct briefly and encourage an immediate retry
26
- **Concrete examples**: Includes examples of good responses vs. ones to avoid
27
-
28
- ### 3. Practice Agent Prompt (prompt.py)
29
- #### Key improvements:
30
- **Natural conversation partner**: Focus on communication rather than perfection
31
- **Very short responses**: 1-2 sentences maximum, one good question
32
- **Partner style**: Genuinely interested, doesn't fill every silence
33
- **Encourage participation**: Leave room for the learner to share more
34
- **Example responses**: Examples of replies that are short but engaging
35
-
36
- ### 4. Switching logic (func.py)
37
- #### Teaching → Practice:
38
- The user shows understanding and confidence
39
- The user asks to practice conversation
40
- The user is ready to communicate in English
41
-
42
- #### Practice → Teaching:
43
- Detailed grammar explanation is needed
44
- Basic errors are repeated many times
45
- The user asks for more structured support
46
-
47
- ### 5. Flow routing (flow.py)
48
- Added fallback logic: default to the Teaching Agent when no active agent is set
49
-
50
- ## Benefits of the Changes
51
-
52
- ### Learner experience:
53
- 1. **Comfortable start**: The Teaching Agent creates a safe environment for learning
54
- 2. **High interaction**: Short responses, always with a question that invites participation
55
- 3. **No overwhelm**: Not too much information at once
56
- 4. **Language flexibility**: Vietnamese when needed, English whenever possible
57
- 5. **Natural transition**: When ready, the learner is encouraged to practice
58
- 6. **A real partner**: Practice mode feels like talking to a real friend, with short replies
59
-
60
- ### Educational effectiveness:
61
- 1. **Structured learning**: Teach first, practice later, in small steps
62
- 2. **High motivation**: A fun, pressure-free environment that always encourages participation
63
- 3. **Sustained attention**: Short responses keep the learner from getting tired
64
- 4. **Continuous interaction**: The learner always has a chance to respond
65
- 5. **Practical application**: Focus on real-world communication
66
- 6. **Confident communication**: Thorough preparation before practice
67
-
68
- ## How to Use
69
-
70
- 1. **Start**: The Teaching Agent greets the learner and begins teaching
71
- 2. **Learn**: Explanations and exercises with support and encouragement
72
- 3. **Ready**: Once the learner is confident, the Teaching Agent hands off to the Practice Agent
73
- 4. **Practice**: Natural conversation with the Practice Agent
74
- 5. **Support**: If help is needed, the Practice Agent switches back to the Teaching Agent
75
-
76
- ## Expected Results
77
- Learners feel comfortable and supported
78
- **They always want to keep interacting** because responses are short and easy to read
79
- The learning process is natural and pressure-free
80
- **No overwhelm** from too much information
81
- Smooth transitions between learning and practice
82
- High motivation and a desire to keep learning
83
- Confident, natural English communication
84
-
85
- ## Example Response Style
86
-
87
- ### Teaching Agent:
88
- ❌ **Avoid**: "That's excellent! You're really making great progress with past tense. Let me explain how irregular verbs work in English. There are many irregular verbs like 'go-went', 'see-saw', 'have-had'..."
89
-
90
- ✅ **Good**: "Good try! Use **went** instead. Can you try again?"
91
-
92
- ### Practice Agent:
93
- ❌ **Avoid**: "That sounds like a really interesting experience! I'd love to hear more about what happened next and how you felt about the whole situation. It must have been quite exciting for you!"
94
-
95
- ✅ **Good**: "Wow, sounds exciting! What happened next?"
requirements.txt CHANGED
@@ -17,4 +17,9 @@ deepgram-sdk
17
  whisper-openai
18
  nltk
19
  librosa
20
- eng-to-ipa
 
 
 
 
 
 
17
  whisper-openai
18
  nltk
19
  librosa
20
+ eng-to-ipa
21
+ onnxruntime
22
+ onnx
23
+ transformers
24
+ torch
25
+ optimum[onnxruntime]
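
A small sanity-check sketch for the newly added ONNX dependencies (assuming they installed cleanly); it only prints versions and the available execution providers.

```python
# Environment sanity check for the new ONNX-related dependencies.
import onnx
import onnxruntime
import torch
import transformers

print("onnxruntime providers:", onnxruntime.get_available_providers())
print("onnx", onnx.__version__, "| torch", torch.__version__, "| transformers", transformers.__version__)
```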
src/apis/__pycache__/create_app.cpython-311.pyc CHANGED
Binary files a/src/apis/__pycache__/create_app.cpython-311.pyc and b/src/apis/__pycache__/create_app.cpython-311.pyc differ
 
src/apis/controllers/speaking_controller.py ADDED
@@ -0,0 +1,1004 @@
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import List, Dict, Optional
5
+ import tempfile
6
+ import os
7
+ import numpy as np
8
+ import librosa
9
+ import nltk
10
+ import eng_to_ipa as ipa
11
+ import torch
12
+ import re
13
+ from collections import defaultdict
14
+ import warnings
15
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
16
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
17
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
18
+ from loguru import logger
19
+ import onnxruntime
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
+ # Download required NLTK data
24
+ try:
25
+ nltk.download("cmudict", quiet=True)
26
+ from nltk.corpus import cmudict
27
+ except:
28
+ print("Warning: NLTK data not available")
29
+
30
+
31
+ class WhisperASR:
32
+ """Whisper ASR for normal mode pronunciation assessment"""
33
+
34
+ def __init__(self, model_name: str = "openai/whisper-base.en"):
35
+ """
36
+ Initialize Whisper model for normal mode
37
+
38
+ Args:
39
+ model_name: HuggingFace model name for Whisper
40
+ """
41
+ print(f"Loading Whisper model: {model_name}")
42
+
43
+ try:
44
+ # Try ONNX first
45
+ self.processor = WhisperProcessor.from_pretrained(model_name)
46
+ self.model = ORTModelForSpeechSeq2Seq.from_pretrained(
47
+ model_name,
48
+ export=True,
49
+ provider="CPUExecutionProvider",
50
+ )
51
+ self.model_type = "ONNX"
52
+ print("Whisper ONNX model loaded successfully")
53
+ except:
54
+ # Fallback to PyTorch
55
+ self.processor = WhisperProcessor.from_pretrained(model_name)
56
+ self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
57
+ self.model_type = "PyTorch"
58
+ print("Whisper PyTorch model loaded successfully")
59
+
60
+ self.model_name = model_name
61
+ self.sample_rate = 16000
62
+
63
+ def transcribe_to_text(self, audio_path: str) -> Dict:
64
+ """
65
+ Transcribe audio to text using Whisper
66
+ Returns transcript and confidence score
67
+ """
68
+ try:
69
+ # Load audio
70
+ audio, sr = librosa.load(audio_path, sr=self.sample_rate)
71
+
72
+ # Process audio
73
+ inputs = self.processor(audio, sampling_rate=16000, return_tensors="pt")
74
+
75
+ # Set language to English
76
+ forced_decoder_ids = self.processor.get_decoder_prompt_ids(
77
+ language="en", task="transcribe"
78
+ )
79
+
80
+ # Generate transcription
81
+ with torch.no_grad():
82
+ predicted_ids = self.model.generate(
83
+ inputs["input_features"],
84
+ forced_decoder_ids=forced_decoder_ids,
85
+ max_new_tokens=200,
86
+ do_sample=False,
87
+ )
88
+
89
+ # Decode to text
90
+ transcript = self.processor.batch_decode(
91
+ predicted_ids, skip_special_tokens=True
92
+ )[0]
93
+ transcript = transcript.strip().lower()
94
+
95
+ # Convert to phoneme representation for comparison
96
+ g2p = SimpleG2P()
97
+ phoneme_representation = g2p.get_reference_phoneme_string(transcript)
98
+
99
+ return {
100
+ "character_transcript": transcript,
101
+ "phoneme_representation": phoneme_representation,
102
+ "confidence_scores": [0.8]
103
+ * len(transcript.split()), # Simple confidence
104
+ }
105
+
106
+ except Exception as e:
107
+ logger.error(f"Whisper transcription error: {e}")
108
+ return {
109
+ "character_transcript": "",
110
+ "phoneme_representation": "",
111
+ "confidence_scores": [],
112
+ }
113
+
114
+ def get_model_info(self) -> Dict:
115
+ """Get information about the loaded Whisper model"""
116
+ return {
117
+ "model_name": self.model_name,
118
+ "model_type": self.model_type,
119
+ "sample_rate": self.sample_rate,
120
+ }
121
+
122
+
123
+ class Wav2Vec2CharacterASRONNX:
124
+ """Wav2Vec2 character-level ASR with ONNX runtime - no language model correction"""
125
+
126
+ def __init__(
127
+ self,
128
+ onnx_model_path: str = "./wav2vec2_asr.onnx",
129
+ processor_name: str = "facebook/wav2vec2-base-960h",
130
+ ):
131
+ """
132
+ Initialize Wav2Vec2 ONNX character-level model
133
+ Automatically creates ONNX model if it doesn't exist
134
+
135
+ Args:
136
+ onnx_model_path: Path to the ONNX model file
137
+ processor_name: HuggingFace model name for the processor
138
+ """
139
+ print(f"Loading Wav2Vec2 ONNX model from: {onnx_model_path}")
140
+ print(f"Loading processor: {processor_name}")
141
+
142
+ # Check if ONNX model exists, if not create it
143
+ if not os.path.exists(onnx_model_path):
144
+ print(f"ONNX model not found at {onnx_model_path}. Creating it...")
145
+ self._create_onnx_model(onnx_model_path, processor_name)
146
+
147
+ try:
148
+ # Load ONNX model
149
+ self.session = onnxruntime.InferenceSession(onnx_model_path)
150
+ self.input_name = self.session.get_inputs()[0].name
151
+ self.output_name = self.session.get_outputs()[0].name
152
+
153
+ # Load processor
154
+ self.processor = Wav2Vec2Processor.from_pretrained(processor_name)
155
+
156
+ print("ONNX Wav2Vec2 character model loaded successfully")
157
+ self.model_name = processor_name
158
+ self.onnx_path = onnx_model_path
159
+ self.sample_rate = 16000
160
+
161
+ except Exception as e:
162
+ print(f"Error loading ONNX model: {e}")
163
+ raise
164
+
165
+ def _create_onnx_model(self, onnx_model_path: str, processor_name: str):
166
+ """Create ONNX model if it doesn't exist"""
167
+ try:
168
+ # Import the converter from model_convert
169
+ from src.model_convert.wav2vec2onnx import Wav2Vec2ONNXConverter
170
+
171
+ print("Creating new ONNX model...")
172
+ converter = Wav2Vec2ONNXConverter(processor_name)
173
+ created_path = converter.convert_to_onnx(
174
+ onnx_path=onnx_model_path,
175
+ input_length=160000, # 10 seconds
176
+ opset_version=14,
177
+ )
178
+ print(f"✓ ONNX model created successfully at: {created_path}")
179
+
180
+ except ImportError as e:
181
+ print(f"Error importing Wav2Vec2ONNXConverter: {e}")
182
+ # Fallback: use the convert_to_onnx.py directly if wav2vec2onnx.py doesn't work
183
+ self._fallback_create_onnx_model(onnx_model_path, processor_name)
184
+
185
+ except Exception as e:
186
+ print(f"Error creating ONNX model: {e}")
187
+ # Try fallback method
188
+ self._fallback_create_onnx_model(onnx_model_path, processor_name)
189
+
190
+ def _fallback_create_onnx_model(self, onnx_model_path: str, processor_name: str):
191
+ """Fallback method to create ONNX model using basic torch.onnx.export"""
192
+ try:
193
+ print("Using fallback method to create ONNX model...")
194
+
195
+ # Load PyTorch model
196
+ model = Wav2Vec2ForCTC.from_pretrained(processor_name)
197
+ model.eval()
198
+
199
+ # Create dummy input
200
+ dummy_input = torch.randn(1, 160000, dtype=torch.float32)
201
+
202
+ # Export to ONNX
203
+ with torch.no_grad():
204
+ torch.onnx.export(
205
+ model,
206
+ dummy_input,
207
+ onnx_model_path,
208
+ input_names=["input_values"],
209
+ output_names=["logits"],
210
+ dynamic_axes={
211
+ "input_values": {0: "batch_size", 1: "sequence_length"},
212
+ "logits": {0: "batch_size", 1: "sequence_length"},
213
+ },
214
+ opset_version=14,
215
+ do_constant_folding=True,
216
+ verbose=False,
217
+ export_params=True,
218
+ )
219
+
220
+ print(f"✓ Fallback ONNX model created at: {onnx_model_path}")
221
+
222
+ except Exception as e:
223
+ print(f"Fallback method also failed: {e}")
224
+ raise Exception(f"Could not create ONNX model: {e}")
225
+
226
+ def transcribe_to_characters(self, audio_path: str) -> Dict:
227
+ """
228
+ Transcribe audio directly to characters using ONNX model (no language model correction)
229
+ Returns raw character sequence as produced by the model
230
+ """
231
+ try:
232
+ # Load audio
233
+ speech, sr = librosa.load(audio_path, sr=self.sample_rate)
234
+
235
+ # Prepare input for ONNX
236
+ input_values = self.processor(
237
+ speech, sampling_rate=self.sample_rate, return_tensors="np"
238
+ ).input_values
239
+
240
+ # Run ONNX inference
241
+ ort_inputs = {self.input_name: input_values}
242
+ ort_outputs = self.session.run([self.output_name], ort_inputs)
243
+ logits = ort_outputs[0]
244
+
245
+ # Get predictions
246
+ predicted_ids = np.argmax(logits, axis=-1)
247
+
248
+ # Decode to characters directly
249
+ character_transcript = self.processor.batch_decode(predicted_ids)[0]
250
+ logger.info(f"character_transcript {character_transcript}")
251
+
252
+ # Clean up character transcript
253
+ character_transcript = self._clean_character_transcript(
254
+ character_transcript
255
+ )
256
+
257
+ # Convert characters to phoneme-like representation
258
+ phoneme_like_transcript = self._characters_to_phoneme_representation(
259
+ character_transcript
260
+ )
261
+
262
+ # Calculate confidence scores
263
+ confidence_scores = self._calculate_confidence_scores(logits)
264
+
265
+ return {
266
+ "character_transcript": character_transcript,
267
+ "phoneme_representation": phoneme_like_transcript,
268
+ "raw_predicted_ids": predicted_ids[0].tolist(),
269
+ "confidence_scores": confidence_scores[:100], # Limit for JSON
270
+ }
271
+
272
+ except Exception as e:
273
+ print(f"Transcription error: {e}")
274
+ return {
275
+ "character_transcript": "",
276
+ "phoneme_representation": "",
277
+ "raw_predicted_ids": [],
278
+ "confidence_scores": [],
279
+ }
280
+
281
+ def _calculate_confidence_scores(self, logits: np.ndarray) -> List[float]:
282
+ """Calculate confidence scores from logits using numpy"""
283
+ # Apply softmax
284
+ exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
285
+ softmax_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
286
+
287
+ # Get max probabilities
288
+ max_probs = np.max(softmax_probs, axis=-1)[0]
289
+ return max_probs.tolist()
290
+
291
+ def _clean_character_transcript(self, transcript: str) -> str:
292
+ """Clean and standardize character transcript"""
293
+ # Remove extra spaces and special tokens
294
+ logger.info(f"Raw transcript before cleaning: {transcript}")
295
+ cleaned = re.sub(r"\s+", " ", transcript)
296
+ cleaned = cleaned.strip().lower()
297
+
298
+ return cleaned
299
+
300
+ def _characters_to_phoneme_representation(self, text: str) -> str:
301
+ """Convert character-based transcript to phoneme-like representation for comparison"""
302
+ # This is a simple character-to-phoneme mapping for pronunciation comparison
303
+ # The idea is to convert the raw character output to something comparable with reference phonemes
304
+
305
+ if not text:
306
+ return ""
307
+
308
+ words = text.split()
309
+ phoneme_words = []
310
+
311
+ # Use our G2P to convert transcript words to phonemes
312
+ g2p = SimpleG2P()
313
+
314
+ for word in words:
315
+ try:
316
+ word_data = g2p.text_to_phonemes(word)[0]
317
+ phoneme_words.extend(word_data["phonemes"])
318
+ except:
319
+ # Fallback: simple letter-to-sound mapping
320
+ phoneme_words.extend(self._simple_letter_to_phoneme(word))
321
+
322
+ return " ".join(phoneme_words)
323
+
324
+ def _simple_letter_to_phoneme(self, word: str) -> List[str]:
325
+ """Simple fallback letter-to-phoneme conversion"""
326
+ letter_to_phoneme = {
327
+ "a": "æ",
328
+ "b": "b",
329
+ "c": "k",
330
+ "d": "d",
331
+ "e": "ɛ",
332
+ "f": "f",
333
+ "g": "ɡ",
334
+ "h": "h",
335
+ "i": "ɪ",
336
+ "j": "dʒ",
337
+ "k": "k",
338
+ "l": "l",
339
+ "m": "m",
340
+ "n": "n",
341
+ "o": "ʌ",
342
+ "p": "p",
343
+ "q": "k",
344
+ "r": "r",
345
+ "s": "s",
346
+ "t": "t",
347
+ "u": "ʌ",
348
+ "v": "v",
349
+ "w": "w",
350
+ "x": "ks",
351
+ "y": "j",
352
+ "z": "z",
353
+ }
354
+
355
+ phonemes = []
356
+ for letter in word.lower():
357
+ if letter in letter_to_phoneme:
358
+ phonemes.append(letter_to_phoneme[letter])
359
+
360
+ return phonemes
361
+
362
+ def get_model_info(self) -> Dict:
363
+ """Get information about the loaded ONNX model"""
364
+ return {
365
+ "onnx_model_path": self.onnx_path,
366
+ "processor_name": self.model_name,
367
+ "input_name": self.input_name,
368
+ "output_name": self.output_name,
369
+ "sample_rate": self.sample_rate,
370
+ "session_providers": self.session.get_providers(),
371
+ }
372
+
373
+
374
+ class SimpleG2P:
375
+ """Simple Grapheme-to-Phoneme converter for reference text"""
376
+
377
+ def __init__(self):
378
+ try:
379
+ self.cmu_dict = cmudict.dict()
380
+ except:
381
+ self.cmu_dict = {}
382
+ print("Warning: CMU dictionary not available")
383
+
384
+ def text_to_phonemes(self, text: str) -> List[Dict]:
385
+ """Convert text to phoneme sequence"""
386
+ words = self._clean_text(text).split()
387
+ phoneme_sequence = []
388
+
389
+ for word in words:
390
+ word_phonemes = self._get_word_phonemes(word)
391
+ phoneme_sequence.append(
392
+ {
393
+ "word": word,
394
+ "phonemes": word_phonemes,
395
+ "ipa": self._get_ipa(word),
396
+ "phoneme_string": " ".join(word_phonemes),
397
+ }
398
+ )
399
+
400
+ return phoneme_sequence
401
+
402
+ def get_reference_phoneme_string(self, text: str) -> str:
403
+ """Get reference phoneme string for comparison"""
404
+ phoneme_sequence = self.text_to_phonemes(text)
405
+ all_phonemes = []
406
+
407
+ for word_data in phoneme_sequence:
408
+ all_phonemes.extend(word_data["phonemes"])
409
+
410
+ return " ".join(all_phonemes)
411
+
412
+ def _clean_text(self, text: str) -> str:
413
+ """Clean text for processing"""
414
+ text = re.sub(r"[^\w\s\']", " ", text)
415
+ text = re.sub(r"\s+", " ", text)
416
+ return text.lower().strip()
417
+
418
+ def _get_word_phonemes(self, word: str) -> List[str]:
419
+ """Get phonemes for a word"""
420
+ word_lower = word.lower()
421
+
422
+ if word_lower in self.cmu_dict:
423
+ # Remove stress markers and convert to Wav2Vec2 phoneme format
424
+ phonemes = self.cmu_dict[word_lower][0]
425
+ clean_phonemes = [re.sub(r"[0-9]", "", p) for p in phonemes]
426
+ return self._convert_to_wav2vec_format(clean_phonemes)
427
+ else:
428
+ return self._estimate_phonemes(word)
429
+
430
+ def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]:
431
+ """Convert CMU phonemes to Wav2Vec2 format"""
432
+ # Mapping from CMU to Wav2Vec2/eSpeak phonemes
433
+ cmu_to_espeak = {
434
+ "AA": "ɑ",
435
+ "AE": "æ",
436
+ "AH": "ʌ",
437
+ "AO": "ɔ",
438
+ "AW": "aʊ",
439
+ "AY": "aɪ",
440
+ "EH": "ɛ",
441
+ "ER": "ɝ",
442
+ "EY": "eɪ",
443
+ "IH": "ɪ",
444
+ "IY": "i",
445
+ "OW": "oʊ",
446
+ "OY": "ɔɪ",
447
+ "UH": "ʊ",
448
+ "UW": "u",
449
+ "B": "b",
450
+ "CH": "tʃ",
451
+ "D": "d",
452
+ "DH": "ð",
453
+ "F": "f",
454
+ "G": "ɡ",
455
+ "HH": "h",
456
+ "JH": "dʒ",
457
+ "K": "k",
458
+ "L": "l",
459
+ "M": "m",
460
+ "N": "n",
461
+ "NG": "ŋ",
462
+ "P": "p",
463
+ "R": "r",
464
+ "S": "s",
465
+ "SH": "ʃ",
466
+ "T": "t",
467
+ "TH": "θ",
468
+ "V": "v",
469
+ "W": "w",
470
+ "Y": "j",
471
+ "Z": "z",
472
+ "ZH": "ʒ",
473
+ }
474
+
475
+ converted = []
476
+ for phoneme in cmu_phonemes:
477
+ converted_phoneme = cmu_to_espeak.get(phoneme, phoneme.lower())
478
+ converted.append(converted_phoneme)
479
+
480
+ return converted
481
+
482
+ def _get_ipa(self, word: str) -> str:
483
+ """Get IPA transcription"""
484
+ try:
485
+ return ipa.convert(word)
486
+ except:
487
+ return f"/{word}/"
488
+
489
+ def _estimate_phonemes(self, word: str) -> List[str]:
490
+ """Estimate phonemes for unknown words"""
491
+ # Basic phoneme estimation with eSpeak-style output
492
+ phoneme_map = {
493
+ "ch": ["tʃ"],
494
+ "sh": ["ʃ"],
495
+ "th": ["θ"],
496
+ "ph": ["f"],
497
+ "ck": ["k"],
498
+ "ng": ["ŋ"],
499
+ "qu": ["k", "w"],
500
+ "a": ["æ"],
501
+ "e": ["ɛ"],
502
+ "i": ["ɪ"],
503
+ "o": ["ʌ"],
504
+ "u": ["ʌ"],
505
+ "b": ["b"],
506
+ "c": ["k"],
507
+ "d": ["d"],
508
+ "f": ["f"],
509
+ "g": ["ɡ"],
510
+ "h": ["h"],
511
+ "j": ["dʒ"],
512
+ "k": ["k"],
513
+ "l": ["l"],
514
+ "m": ["m"],
515
+ "n": ["n"],
516
+ "p": ["p"],
517
+ "r": ["r"],
518
+ "s": ["s"],
519
+ "t": ["t"],
520
+ "v": ["v"],
521
+ "w": ["w"],
522
+ "x": ["k", "s"],
523
+ "y": ["j"],
524
+ "z": ["z"],
525
+ }
526
+
527
+ word = word.lower()
528
+ phonemes = []
529
+ i = 0
530
+
531
+ while i < len(word):
532
+ # Check 2-letter combinations first
533
+ if i <= len(word) - 2:
534
+ two_char = word[i : i + 2]
535
+ if two_char in phoneme_map:
536
+ phonemes.extend(phoneme_map[two_char])
537
+ i += 2
538
+ continue
539
+
540
+ # Single character
541
+ char = word[i]
542
+ if char in phoneme_map:
543
+ phonemes.extend(phoneme_map[char])
544
+
545
+ i += 1
546
+
547
+ return phonemes
548
+
549
+
550
+ class PhonemeComparator:
551
+ """Compare reference and learner phoneme sequences"""
552
+
553
+ def __init__(self):
554
+ # Vietnamese speakers' common phoneme substitutions
555
+ self.substitution_patterns = {
556
+ "θ": ["f", "s", "t"], # TH → F, S, T
557
+ "ð": ["d", "z", "v"], # DH → D, Z, V
558
+ "v": ["w", "f"], # V → W, F
559
+ "r": ["l"], # R → L
560
+ "l": ["r"], # L → R
561
+ "z": ["s"], # Z → S
562
+ "ʒ": ["ʃ", "z"], # ZH → SH, Z
563
+ "ŋ": ["n"], # NG → N
564
+ }
565
+
566
+ # Difficulty levels for Vietnamese speakers
567
+ self.difficulty_map = {
568
+ "θ": 0.9, # th (think)
569
+ "ð": 0.9, # th (this)
570
+ "v": 0.8, # v
571
+ "z": 0.8, # z
572
+ "ʒ": 0.9, # zh (measure)
573
+ "r": 0.7, # r
574
+ "l": 0.6, # l
575
+ "w": 0.5, # w
576
+ "f": 0.4, # f
577
+ "s": 0.3, # s
578
+ "ʃ": 0.5, # sh
579
+ "tʃ": 0.4, # ch
580
+ "dʒ": 0.5, # j
581
+ "ŋ": 0.3, # ng
582
+ }
583
+
584
+ def compare_phoneme_sequences(
585
+ self, reference_phonemes: str, learner_phonemes: str
586
+ ) -> List[Dict]:
587
+ """Compare reference and learner phoneme sequences"""
588
+
589
+ # Split phoneme strings
590
+ ref_phones = reference_phonemes.split()
591
+ learner_phones = learner_phonemes.split()
592
+
593
+ print(f"Reference phonemes: {ref_phones}")
594
+ print(f"Learner phonemes: {learner_phones}")
595
+
596
+ # Simple alignment comparison
597
+ comparisons = []
598
+ max_len = max(len(ref_phones), len(learner_phones))
599
+
600
+ for i in range(max_len):
601
+ ref_phoneme = ref_phones[i] if i < len(ref_phones) else ""
602
+ learner_phoneme = learner_phones[i] if i < len(learner_phones) else ""
603
+
604
+ if ref_phoneme and learner_phoneme:
605
+ # Both present - check accuracy
606
+ if ref_phoneme == learner_phoneme:
607
+ status = "correct"
608
+ score = 1.0
609
+ elif self._is_acceptable_substitution(ref_phoneme, learner_phoneme):
610
+ status = "acceptable"
611
+ score = 0.7
612
+ else:
613
+ status = "wrong"
614
+ score = 0.2
615
+
616
+ elif ref_phoneme and not learner_phoneme:
617
+ # Missing phoneme
618
+ status = "missing"
619
+ score = 0.0
620
+
621
+ elif learner_phoneme and not ref_phoneme:
622
+ # Extra phoneme
623
+ status = "extra"
624
+ score = 0.0
625
+ else:
626
+ continue
627
+
628
+ comparison = {
629
+ "position": i,
630
+ "reference_phoneme": ref_phoneme,
631
+ "learner_phoneme": learner_phoneme,
632
+ "status": status,
633
+ "score": score,
634
+ "difficulty": self.difficulty_map.get(ref_phoneme, 0.3),
635
+ }
636
+
637
+ comparisons.append(comparison)
638
+
639
+ return comparisons
640
+
641
+ def _is_acceptable_substitution(self, reference: str, learner: str) -> bool:
642
+ """Check if learner phoneme is acceptable substitution for Vietnamese speakers"""
643
+ acceptable = self.substitution_patterns.get(reference, [])
644
+ return learner in acceptable
645
+
646
+
647
+ # =============================================================================
648
+ # WORD ANALYZER
649
+ # =============================================================================
650
+
651
+
652
+ class WordAnalyzer:
653
+ """Analyze word-level pronunciation accuracy using character-based ASR"""
654
+
655
+ def __init__(self):
656
+ self.g2p = SimpleG2P()
657
+ self.comparator = PhonemeComparator()
658
+
659
+ def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict:
660
+ """Analyze word-level pronunciation using phoneme representation from character ASR"""
661
+
662
+ # Get reference phonemes by word
663
+ reference_words = self.g2p.text_to_phonemes(reference_text)
664
+
665
+ # Get overall phoneme comparison
666
+ reference_phoneme_string = self.g2p.get_reference_phoneme_string(reference_text)
667
+ phoneme_comparisons = self.comparator.compare_phoneme_sequences(
668
+ reference_phoneme_string, learner_phonemes
669
+ )
670
+
671
+ # Map phonemes back to words
672
+ word_highlights = self._create_word_highlights(
673
+ reference_words, phoneme_comparisons
674
+ )
675
+
676
+ # Identify wrong words
677
+ wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)
678
+
679
+ return {
680
+ "word_highlights": word_highlights,
681
+ "phoneme_differences": phoneme_comparisons,
682
+ "wrong_words": wrong_words,
683
+ }
684
+
685
+ def _create_word_highlights(
686
+ self, reference_words: List[Dict], phoneme_comparisons: List[Dict]
687
+ ) -> List[Dict]:
688
+ """Create word highlighting data"""
689
+
690
+ word_highlights = []
691
+ phoneme_index = 0
692
+
693
+ for word_data in reference_words:
694
+ word = word_data["word"]
695
+ word_phonemes = word_data["phonemes"]
696
+ num_phonemes = len(word_phonemes)
697
+
698
+ # Get phoneme scores for this word
699
+ word_phoneme_scores = []
700
+ for j in range(num_phonemes):
701
+ if phoneme_index + j < len(phoneme_comparisons):
702
+ comparison = phoneme_comparisons[phoneme_index + j]
703
+ word_phoneme_scores.append(comparison["score"])
704
+
705
+ # Calculate word score
706
+ word_score = np.mean(word_phoneme_scores) if word_phoneme_scores else 0.0
707
+
708
+ # Create word highlight
709
+ highlight = {
710
+ "word": word,
711
+ "score": float(word_score),
712
+ "status": self._get_word_status(word_score),
713
+ "color": self._get_word_color(word_score),
714
+ "phonemes": word_phonemes,
715
+ "ipa": word_data["ipa"],
716
+ "phoneme_scores": word_phoneme_scores,
717
+ "phoneme_start_index": phoneme_index,
718
+ "phoneme_end_index": phoneme_index + num_phonemes - 1,
719
+ }
720
+
721
+ word_highlights.append(highlight)
722
+ phoneme_index += num_phonemes
723
+
724
+ return word_highlights
725
+
726
+ def _identify_wrong_words(
727
+ self, word_highlights: List[Dict], phoneme_comparisons: List[Dict]
728
+ ) -> List[Dict]:
729
+ """Identify words that were pronounced incorrectly"""
730
+
731
+ wrong_words = []
732
+
733
+ for word_highlight in word_highlights:
734
+ if word_highlight["score"] < 0.6: # Threshold for wrong pronunciation
735
+
736
+ # Find specific phoneme errors for this word
737
+ start_idx = word_highlight["phoneme_start_index"]
738
+ end_idx = word_highlight["phoneme_end_index"]
739
+
740
+ wrong_phonemes = []
741
+ missing_phonemes = []
742
+
743
+ for i in range(start_idx, min(end_idx + 1, len(phoneme_comparisons))):
744
+ comparison = phoneme_comparisons[i]
745
+
746
+ if comparison["status"] == "wrong":
747
+ wrong_phonemes.append(
748
+ {
749
+ "expected": comparison["reference_phoneme"],
750
+ "actual": comparison["learner_phoneme"],
751
+ "difficulty": comparison["difficulty"],
752
+ }
753
+ )
754
+ elif comparison["status"] == "missing":
755
+ missing_phonemes.append(
756
+ {
757
+ "phoneme": comparison["reference_phoneme"],
758
+ "difficulty": comparison["difficulty"],
759
+ }
760
+ )
761
+
762
+ wrong_word = {
763
+ "word": word_highlight["word"],
764
+ "score": word_highlight["score"],
765
+ "expected_phonemes": word_highlight["phonemes"],
766
+ "ipa": word_highlight["ipa"],
767
+ "wrong_phonemes": wrong_phonemes,
768
+ "missing_phonemes": missing_phonemes,
769
+ "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes),
770
+ }
771
+
772
+ wrong_words.append(wrong_word)
773
+
774
+ return wrong_words
775
+
776
+ def _get_word_status(self, score: float) -> str:
777
+ """Get word status from score"""
778
+ if score >= 0.8:
779
+ return "excellent"
780
+ elif score >= 0.6:
781
+ return "good"
782
+ elif score >= 0.4:
783
+ return "needs_practice"
784
+ else:
785
+ return "poor"
786
+
787
+ def _get_word_color(self, score: float) -> str:
788
+ """Get color for word highlighting"""
789
+ if score >= 0.8:
790
+ return "#22c55e" # Green
791
+ elif score >= 0.6:
792
+ return "#84cc16" # Light green
793
+ elif score >= 0.4:
794
+ return "#eab308" # Yellow
795
+ else:
796
+ return "#ef4444" # Red
797
+
798
+ def _get_vietnamese_tips(
799
+ self, wrong_phonemes: List[Dict], missing_phonemes: List[Dict]
800
+ ) -> List[str]:
801
+ """Get Vietnamese-specific pronunciation tips"""
802
+
803
+ tips = []
804
+
805
+ # Tips for specific Vietnamese pronunciation challenges
806
+ vietnamese_tips = {
807
+ "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
808
+ "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
809
+ "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
810
+ "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
811
+ "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
812
+ "z": "Giống âm 's' nhưng có rung dây thanh âm",
813
+ "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
814
+ "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'",
815
+ }
816
+
817
+ # Add tips for wrong phonemes
818
+ for wrong in wrong_phonemes:
819
+ expected = wrong["expected"]
820
+ actual = wrong["actual"]
821
+
822
+ if expected in vietnamese_tips:
823
+ tips.append(f"Âm '{expected}': {vietnamese_tips[expected]}")
824
+ else:
825
+ tips.append(f"Luyện âm '{expected}' thay vì '{actual}'")
826
+
827
+ # Add tips for missing phonemes
828
+ for missing in missing_phonemes:
829
+ phoneme = missing["phoneme"]
830
+ if phoneme in vietnamese_tips:
831
+ tips.append(f"Thiếu âm '{phoneme}': {vietnamese_tips[phoneme]}")
832
+
833
+ return tips
834
+
835
+
836
+ class SimpleFeedbackGenerator:
837
+ """Generate simple, actionable feedback in Vietnamese"""
838
+
839
+ def generate_feedback(
840
+ self,
841
+ overall_score: float,
842
+ wrong_words: List[Dict],
843
+ phoneme_comparisons: List[Dict],
844
+ ) -> List[str]:
845
+ """Generate Vietnamese feedback"""
846
+
847
+ feedback = []
848
+
849
+ # Overall feedback in Vietnamese
850
+ if overall_score >= 0.8:
851
+ feedback.append("Phát âm rất tốt! Bạn đã làm xuất sắc.")
852
+ elif overall_score >= 0.6:
853
+ feedback.append("Phát âm khá tốt, còn một vài điểm cần cải thiện.")
854
+ elif overall_score >= 0.4:
855
+ feedback.append(
856
+ "Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ."
857
+ )
858
+ else:
859
+ feedback.append("Hãy luyện tập chậm và rõ ràng hơn.")
860
+
861
+ # Wrong words feedback
862
+ if wrong_words:
863
+ if len(wrong_words) <= 3:
864
+ word_names = [w["word"] for w in wrong_words]
865
+ feedback.append(f"Các từ cần luyện tập: {', '.join(word_names)}")
866
+ else:
867
+ feedback.append(
868
+ f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một."
869
+ )
870
+
871
+ # Most problematic phonemes
872
+ problem_phonemes = defaultdict(int)
873
+ for comparison in phoneme_comparisons:
874
+ if comparison["status"] in ["wrong", "missing"]:
875
+ phoneme = comparison["reference_phoneme"]
876
+ problem_phonemes[phoneme] += 1
877
+
878
+ if problem_phonemes:
879
+ most_difficult = sorted(
880
+ problem_phonemes.items(), key=lambda x: x[1], reverse=True
881
+ )
882
+ top_problem = most_difficult[0][0]
883
+
884
+ phoneme_tips = {
885
+ "θ": "Lưỡi giữa răng, thổi nhẹ",
886
+ "ð": "Lưỡi giữa răng, rung dây thanh",
887
+ "v": "Môi dưới chạm răng trên",
888
+ "r": "Cuộn lưỡi, không chạm vòm miệng",
889
+ "l": "Lưỡi chạm vòm miệng",
890
+ "z": "Như 's' nhưng rung dây thanh",
891
+ }
892
+
893
+ if top_problem in phoneme_tips:
894
+ feedback.append(
895
+ f"Âm khó nhất '{top_problem}': {phoneme_tips[top_problem]}"
896
+ )
897
+
898
+ return feedback
899
+
900
+
901
+ class SimplePronunciationAssessor:
902
+ """Main pronunciation assessor supporting both normal (Whisper) and advanced (Wav2Vec2) modes"""
903
+
904
+ def __init__(self):
905
+ print("Initializing Simple Pronunciation Assessor...")
906
+ self.wav2vec2_asr = Wav2Vec2CharacterASRONNX() # Advanced mode
907
+ self.whisper_asr = WhisperASR() # Normal mode
908
+ self.word_analyzer = WordAnalyzer()
909
+ self.feedback_generator = SimpleFeedbackGenerator()
910
+ print("Initialization completed")
911
+
912
+ def assess_pronunciation(
913
+ self, audio_path: str, reference_text: str, mode: str = "normal"
914
+ ) -> Dict:
915
+ """
916
+ Main assessment function with mode selection
917
+
918
+ Args:
919
+ audio_path: Path to audio file
920
+ reference_text: Reference text to compare
921
+ mode: 'normal' (Whisper) or 'advanced' (Wav2Vec2)
922
+
923
+ Output: Word highlights + Phoneme differences + Wrong words
924
+ """
925
+
926
+ print(f"Starting pronunciation assessment in {mode} mode...")
927
+
928
+ # Step 1: Choose ASR model based on mode
929
+ if mode == "advanced":
930
+ print("Step 1: Using Wav2Vec2 character transcription...")
931
+ asr_result = self.wav2vec2_asr.transcribe_to_characters(audio_path)
932
+ model_info = f"Wav2Vec2-Character ({self.wav2vec2_asr.model_name})"
933
+ else: # normal mode
934
+ print("Step 1: Using Whisper transcription...")
935
+ asr_result = self.whisper_asr.transcribe_to_text(audio_path)
936
+ model_info = f"Whisper ({self.whisper_asr.model_name})"
937
+
938
+ character_transcript = asr_result["character_transcript"]
939
+ phoneme_representation = asr_result["phoneme_representation"]
940
+
941
+ print(f"Character transcript: {character_transcript}")
942
+ print(f"Phoneme representation: {phoneme_representation}")
943
+
944
+ # Step 2: Word analysis using phoneme representation
945
+ print("Step 2: Analyzing words...")
946
+ analysis_result = self.word_analyzer.analyze_words(
947
+ reference_text, phoneme_representation
948
+ )
949
+
950
+ # Step 3: Calculate overall score
951
+ phoneme_comparisons = analysis_result["phoneme_differences"]
952
+ overall_score = self._calculate_overall_score(phoneme_comparisons)
953
+
954
+ # Step 4: Generate feedback
955
+ print("Step 3: Generating feedback...")
956
+ feedback = self.feedback_generator.generate_feedback(
957
+ overall_score, analysis_result["wrong_words"], phoneme_comparisons
958
+ )
959
+
960
+ result = {
961
+ "transcript": character_transcript, # What user actually said
962
+ "transcript_phonemes": phoneme_representation,
963
+ "user_phonemes": phoneme_representation, # Alias for UI clarity
964
+ "character_transcript": character_transcript,
965
+ "overall_score": overall_score,
966
+ "word_highlights": analysis_result["word_highlights"],
967
+ "phoneme_differences": phoneme_comparisons,
968
+ "wrong_words": analysis_result["wrong_words"],
969
+ "feedback": feedback,
970
+ "processing_info": {
971
+ "model_used": model_info,
972
+ "mode": mode,
973
+ "character_based": mode == "advanced",
974
+ "language_model_correction": mode == "normal",
975
+ "raw_output": mode == "advanced",
976
+ },
977
+ }
978
+
979
+ print("Assessment completed successfully")
980
+ return result
981
+
982
+ def _calculate_overall_score(self, phoneme_comparisons: List[Dict]) -> float:
983
+ """Calculate overall pronunciation score"""
984
+ if not phoneme_comparisons:
985
+ return 0.0
986
+
987
+ total_score = sum(comparison["score"] for comparison in phoneme_comparisons)
988
+ return total_score / len(phoneme_comparisons)
989
+
990
+
991
+ def convert_numpy_types(obj):
992
+ """Convert numpy types to Python native types"""
993
+ if isinstance(obj, np.integer):
994
+ return int(obj)
995
+ elif isinstance(obj, np.floating):
996
+ return float(obj)
997
+ elif isinstance(obj, np.ndarray):
998
+ return obj.tolist()
999
+ elif isinstance(obj, dict):
1000
+ return {key: convert_numpy_types(value) for key, value in obj.items()}
1001
+ elif isinstance(obj, list):
1002
+ return [convert_numpy_types(item) for item in obj]
1003
+ else:
1004
+ return obj
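
For reference, a minimal sketch of exercising the new controller outside FastAPI, using only names defined in the file above; `sample.wav` is a placeholder audio path, not part of the repository.

```python
# Local smoke-test sketch for the new pronunciation assessor.
# "sample.wav" is a placeholder; librosa resamples the audio to 16 kHz on load.
from src.apis.controllers.speaking_controller import (
    SimplePronunciationAssessor,
    convert_numpy_types,
)

assessor = SimplePronunciationAssessor()
result = assessor.assess_pronunciation(
    audio_path="sample.wav",
    reference_text="hello world",
    mode="advanced",  # "normal" -> Whisper, "advanced" -> Wav2Vec2 ONNX
)

result = convert_numpy_types(result)  # make numpy values JSON-serializable
print(result["overall_score"])
print(result["wrong_words"])
```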
src/apis/routes/speaking_route.py CHANGED
@@ -1,38 +1,23 @@
1
- # PRONUNCIATION ASSESSMENT USING WAV2VEC2PHONEME
2
- # Input: Audio + Reference Text → Output: Word highlights + Phoneme diff + Wrong words
3
- # Uses Wav2Vec2Phoneme for accurate phoneme-level transcription without language model correction
4
-
5
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter
6
- from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
- from typing import List, Dict, Optional
9
  import tempfile
10
- import os
11
  import numpy as np
12
- import librosa
13
- import nltk
14
- import eng_to_ipa as ipa
15
- import torch
16
  import re
17
- from collections import defaultdict
18
  import warnings
19
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2PhonemeCTCTokenizer
 
 
 
 
 
 
20
 
21
  warnings.filterwarnings("ignore")
22
 
23
- # Download required NLTK data
24
- try:
25
- nltk.download("cmudict", quiet=True)
26
- from nltk.corpus import cmudict
27
- except:
28
- print("Warning: NLTK data not available")
29
-
30
- # =============================================================================
31
- # MODELS
32
- # =============================================================================
33
-
34
  router = APIRouter(prefix="/pronunciation", tags=["Pronunciation"])
35
 
 
36
  class PronunciationAssessmentResult(BaseModel):
37
  transcript: str # What the user actually said (character transcript)
38
  transcript_phonemes: str # User's phonemes
@@ -45,843 +30,145 @@ class PronunciationAssessmentResult(BaseModel):
45
  feedback: List[str]
46
  processing_info: Dict
47
 
48
- # =============================================================================
49
- # WAV2VEC2 PHONEME ASR
50
- # =============================================================================
51
-
52
- class Wav2Vec2CharacterASR:
53
- """Wav2Vec2 character-level ASR without language model correction"""
54
-
55
- def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
56
- """
57
- Initialize Wav2Vec2 character-level model
58
- Available models:
59
- - facebook/wav2vec2-large-960h-lv60-self (character-level, no LM)
60
- - facebook/wav2vec2-base-960h (character-level, no LM)
61
- - facebook/wav2vec2-large-960h (character-level, no LM)
62
- """
63
- print(f"Loading Wav2Vec2 character model: {model_name}")
64
-
65
- try:
66
- self.processor = Wav2Vec2Processor.from_pretrained(model_name)
67
- self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
68
- self.model.eval()
69
- print("Wav2Vec2 character model loaded successfully")
70
- self.model_name = model_name
71
- except Exception as e:
72
- print(f"Error loading model {model_name}: {e}")
73
- # Fallback to base model
74
- fallback_model = "facebook/wav2vec2-base-960h"
75
- print(f"Trying fallback model: {fallback_model}")
76
- try:
77
- self.processor = Wav2Vec2Processor.from_pretrained(fallback_model)
78
- self.model = Wav2Vec2ForCTC.from_pretrained(fallback_model)
79
- self.model.eval()
80
- self.model_name = fallback_model
81
- print("Fallback model loaded successfully")
82
- except Exception as e2:
83
- raise Exception(f"Failed to load both models. Original error: {e}, Fallback error: {e2}")
84
-
85
- self.sample_rate = 16000
86
-
87
- def transcribe_to_characters(self, audio_path: str) -> Dict:
88
- """
89
- Transcribe audio directly to characters (no language model correction)
90
- Returns raw character sequence as produced by the model
91
- """
92
- try:
93
- # Load audio
94
- speech, sr = librosa.load(audio_path, sr=self.sample_rate)
95
-
96
- # Prepare input
97
- input_values = self.processor(
98
- speech,
99
- sampling_rate=self.sample_rate,
100
- return_tensors="pt"
101
- ).input_values
102
-
103
- # Get model predictions (no language model involved)
104
- with torch.no_grad():
105
- logits = self.model(input_values).logits
106
- predicted_ids = torch.argmax(logits, dim=-1)
107
-
108
- # Decode to characters directly
109
- character_transcript = self.processor.batch_decode(predicted_ids)[0]
110
-
111
- # Clean up character transcript
112
- character_transcript = self._clean_character_transcript(character_transcript)
113
-
114
- # Convert characters to phoneme-like representation
115
- phoneme_like_transcript = self._characters_to_phoneme_representation(character_transcript)
116
-
117
- return {
118
- "character_transcript": character_transcript,
119
- "phoneme_representation": phoneme_like_transcript,
120
- "raw_predicted_ids": predicted_ids[0].tolist(),
121
- "confidence_scores": torch.softmax(logits, dim=-1).max(dim=-1)[0][0].tolist()[:100] # Limit for JSON
122
- }
123
-
124
- except Exception as e:
125
- print(f"Transcription error: {e}")
126
- return {
127
- "character_transcript": "",
128
- "phoneme_representation": "",
129
- "raw_predicted_ids": [],
130
- "confidence_scores": []
131
- }
132
-
133
- def _clean_character_transcript(self, transcript: str) -> str:
134
- """Clean and standardize character transcript"""
135
- # Remove extra spaces and special tokens
136
- cleaned = re.sub(r'\s+', ' ', transcript)
137
- cleaned = cleaned.strip().lower()
138
-
139
- return cleaned
140
-
141
- def _characters_to_phoneme_representation(self, text: str) -> str:
142
- """Convert character-based transcript to phoneme-like representation for comparison"""
143
- # This is a simple character-to-phoneme mapping for pronunciation comparison
144
- # The idea is to convert the raw character output to something comparable with reference phonemes
145
-
146
- if not text:
147
- return ""
148
-
149
- words = text.split()
150
- phoneme_words = []
151
-
152
- # Use our G2P to convert transcript words to phonemes
153
- g2p = SimpleG2P()
154
-
155
- for word in words:
156
- try:
157
- word_data = g2p.text_to_phonemes(word)[0]
158
- phoneme_words.extend(word_data["phonemes"])
159
- except:
160
- # Fallback: simple letter-to-sound mapping
161
- phoneme_words.extend(self._simple_letter_to_phoneme(word))
162
-
163
- return " ".join(phoneme_words)
164
-
165
- def _simple_letter_to_phoneme(self, word: str) -> List[str]:
166
- """Simple fallback letter-to-phoneme conversion"""
167
- letter_to_phoneme = {
168
- 'a': 'æ', 'b': 'b', 'c': 'k', 'd': 'd', 'e': 'ɛ',
169
- 'f': 'f', 'g': 'ɡ', 'h': 'h', 'i': 'ɪ', 'j': 'dʒ',
170
- 'k': 'k', 'l': 'l', 'm': 'm', 'n': 'n', 'o': 'ʌ',
171
- 'p': 'p', 'q': 'k', 'r': 'r', 's': 's', 't': 't',
172
- 'u': 'ʌ', 'v': 'v', 'w': 'w', 'x': 'ks', 'y': 'j', 'z': 'z'
173
- }
174
-
175
- phonemes = []
176
- for letter in word.lower():
177
- if letter in letter_to_phoneme:
178
- phonemes.append(letter_to_phoneme[letter])
179
-
180
- return phonemes
181
-
182
- # =============================================================================
183
- # SIMPLE G2P FOR REFERENCE
184
- # =============================================================================
185
-
186
- class SimpleG2P:
187
- """Simple Grapheme-to-Phoneme converter for reference text"""
188
-
189
- def __init__(self):
190
- try:
191
- self.cmu_dict = cmudict.dict()
192
- except:
193
- self.cmu_dict = {}
194
- print("Warning: CMU dictionary not available")
195
-
196
- def text_to_phonemes(self, text: str) -> List[Dict]:
197
- """Convert text to phoneme sequence"""
198
- words = self._clean_text(text).split()
199
- phoneme_sequence = []
200
-
201
- for word in words:
202
- word_phonemes = self._get_word_phonemes(word)
203
- phoneme_sequence.append({
204
- "word": word,
205
- "phonemes": word_phonemes,
206
- "ipa": self._get_ipa(word),
207
- "phoneme_string": " ".join(word_phonemes)
208
- })
209
-
210
- return phoneme_sequence
211
-
212
- def get_reference_phoneme_string(self, text: str) -> str:
213
- """Get reference phoneme string for comparison"""
214
- phoneme_sequence = self.text_to_phonemes(text)
215
- all_phonemes = []
216
-
217
- for word_data in phoneme_sequence:
218
- all_phonemes.extend(word_data["phonemes"])
219
-
220
- return " ".join(all_phonemes)
221
-
222
- def _clean_text(self, text: str) -> str:
223
- """Clean text for processing"""
224
- text = re.sub(r"[^\w\s\']", " ", text)
225
- text = re.sub(r"\s+", " ", text)
226
- return text.lower().strip()
227
-
228
- def _get_word_phonemes(self, word: str) -> List[str]:
229
- """Get phonemes for a word"""
230
- word_lower = word.lower()
231
-
232
- if word_lower in self.cmu_dict:
233
- # Remove stress markers and convert to Wav2Vec2 phoneme format
234
- phonemes = self.cmu_dict[word_lower][0]
235
- clean_phonemes = [re.sub(r"[0-9]", "", p) for p in phonemes]
236
- return self._convert_to_wav2vec_format(clean_phonemes)
237
- else:
238
- return self._estimate_phonemes(word)
239
-
240
- def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]:
241
- """Convert CMU phonemes to Wav2Vec2 format"""
242
- # Mapping from CMU to Wav2Vec2/eSpeak phonemes
243
- cmu_to_espeak = {
244
- "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ",
245
- "AY": "aɪ", "EH": "ɛ", "ER": "ɝ", "EY": "eɪ", "IH": "ɪ",
246
- "IY": "i", "OW": "oʊ", "OY": "ɔɪ", "UH": "ʊ", "UW": "u",
247
- "B": "b", "CH": "tʃ", "D": "d", "DH": "ð", "F": "f",
248
- "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k", "L": "l",
249
- "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "r",
250
- "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v",
251
- "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ"
252
- }
253
-
254
- converted = []
255
- for phoneme in cmu_phonemes:
256
- converted_phoneme = cmu_to_espeak.get(phoneme, phoneme.lower())
257
- converted.append(converted_phoneme)
258
-
259
- return converted
260
-
261
- def _get_ipa(self, word: str) -> str:
262
- """Get IPA transcription"""
263
- try:
264
- return ipa.convert(word)
265
- except:
266
- return f"/{word}/"
267
-
268
- def _estimate_phonemes(self, word: str) -> List[str]:
269
- """Estimate phonemes for unknown words"""
270
- # Basic phoneme estimation with eSpeak-style output
271
- phoneme_map = {
272
- "ch": ["tʃ"], "sh": ["ʃ"], "th": ["θ"], "ph": ["f"],
273
- "ck": ["k"], "ng": ["ŋ"], "qu": ["k", "w"],
274
- "a": ["æ"], "e": ["ɛ"], "i": ["ɪ"], "o": ["ʌ"], "u": ["ʌ"],
275
- "b": ["b"], "c": ["k"], "d": ["d"], "f": ["f"], "g": ["ɡ"],
276
- "h": ["h"], "j": ["dʒ"], "k": ["k"], "l": ["l"], "m": ["m"],
277
- "n": ["n"], "p": ["p"], "r": ["r"], "s": ["s"], "t": ["t"],
278
- "v": ["v"], "w": ["w"], "x": ["k", "s"], "y": ["j"], "z": ["z"]
279
- }
280
-
281
- word = word.lower()
282
- phonemes = []
283
- i = 0
284
-
285
- while i < len(word):
286
- # Check 2-letter combinations first
287
- if i <= len(word) - 2:
288
- two_char = word[i:i+2]
289
- if two_char in phoneme_map:
290
- phonemes.extend(phoneme_map[two_char])
291
- i += 2
292
- continue
293
-
294
- # Single character
295
- char = word[i]
296
- if char in phoneme_map:
297
- phonemes.extend(phoneme_map[char])
298
-
299
- i += 1
300
-
301
- return phonemes
302
-
303
- # =============================================================================
304
- # PHONEME COMPARATOR
305
- # =============================================================================
306
-
307
- class PhonemeComparator:
308
- """Compare reference and learner phoneme sequences"""
309
-
310
- def __init__(self):
311
- # Vietnamese speakers' common phoneme substitutions
312
- self.substitution_patterns = {
313
- "θ": ["f", "s", "t"], # TH → F, S, T
314
- "ð": ["d", "z", "v"], # DH → D, Z, V
315
- "v": ["w", "f"], # V → W, F
316
- "r": ["l"], # R → L
317
- "l": ["r"], # L → R
318
- "z": ["s"], # Z → S
319
- "ʒ": ["ʃ", "z"], # ZH → SH, Z
320
- "ŋ": ["n"], # NG → N
321
- }
322
-
323
- # Difficulty levels for Vietnamese speakers
324
- self.difficulty_map = {
325
- "θ": 0.9, # th (think)
326
- "ð": 0.9, # th (this)
327
- "v": 0.8, # v
328
- "z": 0.8, # z
329
- "ʒ": 0.9, # zh (measure)
330
- "r": 0.7, # r
331
- "l": 0.6, # l
332
- "w": 0.5, # w
333
- "f": 0.4, # f
334
- "s": 0.3, # s
335
- "ʃ": 0.5, # sh
336
- "tʃ": 0.4, # ch
337
- "dʒ": 0.5, # j
338
- "ŋ": 0.3, # ng
339
- }
340
-
341
- def compare_phoneme_sequences(self, reference_phonemes: str,
342
- learner_phonemes: str) -> List[Dict]:
343
- """Compare reference and learner phoneme sequences"""
344
-
345
- # Split phoneme strings
346
- ref_phones = reference_phonemes.split()
347
- learner_phones = learner_phonemes.split()
348
-
349
- print(f"Reference phonemes: {ref_phones}")
350
- print(f"Learner phonemes: {learner_phones}")
351
-
352
- # Simple alignment comparison
353
- comparisons = []
354
- max_len = max(len(ref_phones), len(learner_phones))
355
-
356
- for i in range(max_len):
357
- ref_phoneme = ref_phones[i] if i < len(ref_phones) else ""
358
- learner_phoneme = learner_phones[i] if i < len(learner_phones) else ""
359
-
360
- if ref_phoneme and learner_phoneme:
361
- # Both present - check accuracy
362
- if ref_phoneme == learner_phoneme:
363
- status = "correct"
364
- score = 1.0
365
- elif self._is_acceptable_substitution(ref_phoneme, learner_phoneme):
366
- status = "acceptable"
367
- score = 0.7
368
- else:
369
- status = "wrong"
370
- score = 0.2
371
-
372
- elif ref_phoneme and not learner_phoneme:
373
- # Missing phoneme
374
- status = "missing"
375
- score = 0.0
376
-
377
- elif learner_phoneme and not ref_phoneme:
378
- # Extra phoneme
379
- status = "extra"
380
- score = 0.0
381
- else:
382
- continue
383
-
384
- comparison = {
385
- "position": i,
386
- "reference_phoneme": ref_phoneme,
387
- "learner_phoneme": learner_phoneme,
388
- "status": status,
389
- "score": score,
390
- "difficulty": self.difficulty_map.get(ref_phoneme, 0.3)
391
- }
392
-
393
- comparisons.append(comparison)
394
-
395
- return comparisons
396
-
397
- def _is_acceptable_substitution(self, reference: str, learner: str) -> bool:
398
- """Check if learner phoneme is acceptable substitution for Vietnamese speakers"""
399
- acceptable = self.substitution_patterns.get(reference, [])
400
- return learner in acceptable
401
-
402
- # =============================================================================
403
- # WORD ANALYZER
404
- # =============================================================================
405
-
406
- class WordAnalyzer:
407
- """Analyze word-level pronunciation accuracy using character-based ASR"""
408
-
409
- def __init__(self):
410
- self.g2p = SimpleG2P()
411
- self.comparator = PhonemeComparator()
412
-
413
- def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict:
414
- """Analyze word-level pronunciation using phoneme representation from character ASR"""
415
-
416
- # Get reference phonemes by word
417
- reference_words = self.g2p.text_to_phonemes(reference_text)
418
-
419
- # Get overall phoneme comparison
420
- reference_phoneme_string = self.g2p.get_reference_phoneme_string(reference_text)
421
- phoneme_comparisons = self.comparator.compare_phoneme_sequences(
422
- reference_phoneme_string, learner_phonemes
423
- )
424
-
425
- # Map phonemes back to words
426
- word_highlights = self._create_word_highlights(reference_words, phoneme_comparisons)
427
-
428
- # Identify wrong words
429
- wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)
430
-
431
- return {
432
- "word_highlights": word_highlights,
433
- "phoneme_differences": phoneme_comparisons,
434
- "wrong_words": wrong_words
435
- }
436
-
437
- def _create_word_highlights(self, reference_words: List[Dict],
438
- phoneme_comparisons: List[Dict]) -> List[Dict]:
439
- """Create word highlighting data"""
440
-
441
- word_highlights = []
442
- phoneme_index = 0
443
-
444
- for word_data in reference_words:
445
- word = word_data["word"]
446
- word_phonemes = word_data["phonemes"]
447
- num_phonemes = len(word_phonemes)
448
-
449
- # Get phoneme scores for this word
450
- word_phoneme_scores = []
451
- for j in range(num_phonemes):
452
- if phoneme_index + j < len(phoneme_comparisons):
453
- comparison = phoneme_comparisons[phoneme_index + j]
454
- word_phoneme_scores.append(comparison["score"])
455
-
456
- # Calculate word score
457
- word_score = np.mean(word_phoneme_scores) if word_phoneme_scores else 0.0
458
-
459
- # Create word highlight
460
- highlight = {
461
- "word": word,
462
- "score": float(word_score),
463
- "status": self._get_word_status(word_score),
464
- "color": self._get_word_color(word_score),
465
- "phonemes": word_phonemes,
466
- "ipa": word_data["ipa"],
467
- "phoneme_scores": word_phoneme_scores,
468
- "phoneme_start_index": phoneme_index,
469
- "phoneme_end_index": phoneme_index + num_phonemes - 1
470
- }
471
-
472
- word_highlights.append(highlight)
473
- phoneme_index += num_phonemes
474
-
475
- return word_highlights
476
-
477
- def _identify_wrong_words(self, word_highlights: List[Dict],
478
- phoneme_comparisons: List[Dict]) -> List[Dict]:
479
- """Identify words that were pronounced incorrectly"""
480
-
481
- wrong_words = []
482
-
483
- for word_highlight in word_highlights:
484
- if word_highlight["score"] < 0.6: # Threshold for wrong pronunciation
485
-
486
- # Find specific phoneme errors for this word
487
- start_idx = word_highlight["phoneme_start_index"]
488
- end_idx = word_highlight["phoneme_end_index"]
489
-
490
- wrong_phonemes = []
491
- missing_phonemes = []
492
-
493
- for i in range(start_idx, min(end_idx + 1, len(phoneme_comparisons))):
494
- comparison = phoneme_comparisons[i]
495
-
496
- if comparison["status"] == "wrong":
497
- wrong_phonemes.append({
498
- "expected": comparison["reference_phoneme"],
499
- "actual": comparison["learner_phoneme"],
500
- "difficulty": comparison["difficulty"]
501
- })
502
- elif comparison["status"] == "missing":
503
- missing_phonemes.append({
504
- "phoneme": comparison["reference_phoneme"],
505
- "difficulty": comparison["difficulty"]
506
- })
507
-
508
- wrong_word = {
509
- "word": word_highlight["word"],
510
- "score": word_highlight["score"],
511
- "expected_phonemes": word_highlight["phonemes"],
512
- "ipa": word_highlight["ipa"],
513
- "wrong_phonemes": wrong_phonemes,
514
- "missing_phonemes": missing_phonemes,
515
- "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes)
516
- }
517
-
518
- wrong_words.append(wrong_word)
519
-
520
- return wrong_words
521
-
522
- def _get_word_status(self, score: float) -> str:
523
- """Get word status from score"""
524
- if score >= 0.8:
525
- return "excellent"
526
- elif score >= 0.6:
527
- return "good"
528
- elif score >= 0.4:
529
- return "needs_practice"
530
- else:
531
- return "poor"
532
-
533
- def _get_word_color(self, score: float) -> str:
534
- """Get color for word highlighting"""
535
- if score >= 0.8:
536
- return "#22c55e" # Green
537
- elif score >= 0.6:
538
- return "#84cc16" # Light green
539
- elif score >= 0.4:
540
- return "#eab308" # Yellow
541
- else:
542
- return "#ef4444" # Red
543
-
544
- def _get_vietnamese_tips(self, wrong_phonemes: List[Dict],
545
- missing_phonemes: List[Dict]) -> List[str]:
546
- """Get Vietnamese-specific pronunciation tips"""
547
-
548
- tips = []
549
-
550
- # Tips for specific Vietnamese pronunciation challenges
551
- vietnamese_tips = {
552
- "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
553
- "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
554
- "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
555
- "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
556
- "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
557
- "z": "Giống âm 's' nhưng có rung dây thanh âm",
558
- "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
559
- "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'"
560
- }
561
-
562
- # Add tips for wrong phonemes
563
- for wrong in wrong_phonemes:
564
- expected = wrong["expected"]
565
- actual = wrong["actual"]
566
-
567
- if expected in vietnamese_tips:
568
- tips.append(f"Âm '{expected}': {vietnamese_tips[expected]}")
569
- else:
570
- tips.append(f"Luyện âm '{expected}' thay vì '{actual}'")
571
-
572
- # Add tips for missing phonemes
573
- for missing in missing_phonemes:
574
- phoneme = missing["phoneme"]
575
- if phoneme in vietnamese_tips:
576
- tips.append(f"Thiếu âm '{phoneme}': {vietnamese_tips[phoneme]}")
577
-
578
- return tips
579
-
580
- # =============================================================================
581
- # FEEDBACK GENERATOR
582
- # =============================================================================
583
-
584
- class SimpleFeedbackGenerator:
585
- """Generate simple, actionable feedback in Vietnamese"""
586
-
587
- def generate_feedback(self, overall_score: float, wrong_words: List[Dict],
588
- phoneme_comparisons: List[Dict]) -> List[str]:
589
- """Generate Vietnamese feedback"""
590
-
591
- feedback = []
592
-
593
- # Overall feedback in Vietnamese
594
- if overall_score >= 0.8:
595
- feedback.append("Phát âm rất tốt! Bạn đã làm xuất sắc.")
596
- elif overall_score >= 0.6:
597
- feedback.append("Phát âm khá tốt, còn một vài điểm cần cải thiện.")
598
- elif overall_score >= 0.4:
599
- feedback.append("Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ.")
600
- else:
601
- feedback.append("Hãy luyện tập chậm và rõ ràng hơn.")
602
-
603
- # Wrong words feedback
604
- if wrong_words:
605
- if len(wrong_words) <= 3:
606
- word_names = [w["word"] for w in wrong_words]
607
- feedback.append(f"Các từ cần luyện tập: {', '.join(word_names)}")
608
- else:
609
- feedback.append(f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một.")
610
-
611
- # Most problematic phonemes
612
- problem_phonemes = defaultdict(int)
613
- for comparison in phoneme_comparisons:
614
- if comparison["status"] in ["wrong", "missing"]:
615
- phoneme = comparison["reference_phoneme"]
616
- problem_phonemes[phoneme] += 1
617
-
618
- if problem_phonemes:
619
- most_difficult = sorted(problem_phonemes.items(), key=lambda x: x[1], reverse=True)
620
- top_problem = most_difficult[0][0]
621
-
622
- phoneme_tips = {
623
- "θ": "Lưỡi giữa răng, thổi nhẹ",
624
- "ð": "Lưỡi giữa răng, rung dây thanh",
625
- "v": "Môi dưới chạm răng trên",
626
- "r": "Cuộn lưỡi, không chạm vòm miệng",
627
- "l": "Lưỡi chạm vòm miệng",
628
- "z": "Như 's' nhưng rung dây thanh"
629
- }
630
-
631
- if top_problem in phoneme_tips:
632
- feedback.append(f"Âm khó nhất '{top_problem}': {phoneme_tips[top_problem]}")
633
-
634
- return feedback
635
-
636
- # =============================================================================
637
- # MAIN PRONUNCIATION ASSESSOR
638
- # =============================================================================
639
-
640
- class SimplePronunciationAssessor:
641
- """Main pronunciation assessor using Wav2Vec2 character-level model"""
642
-
643
- def __init__(self):
644
- print("Initializing Simple Pronunciation Assessor...")
645
- self.asr = Wav2Vec2CharacterASR() # Updated to use character-based ASR
646
- self.word_analyzer = WordAnalyzer()
647
- self.feedback_generator = SimpleFeedbackGenerator()
648
- print("Initialization completed")
649
-
650
- def assess_pronunciation(self, audio_path: str, reference_text: str) -> Dict:
651
- """
652
- Main assessment function
653
- Input: Audio path + Reference text
654
- Output: Word highlights + Phoneme differences + Wrong words
655
- """
656
-
657
- print("Starting pronunciation assessment...")
658
-
659
- # Step 1: Wav2Vec2 character transcription (no language model)
660
- print("Step 1: Transcribing to characters...")
661
- asr_result = self.asr.transcribe_to_characters(audio_path)
662
- character_transcript = asr_result["character_transcript"]
663
- phoneme_representation = asr_result["phoneme_representation"]
664
-
665
- print(f"Character transcript: {character_transcript}")
666
- print(f"Phoneme representation: {phoneme_representation}")
667
-
668
- # Step 2: Word analysis using phoneme representation
669
- print("Step 2: Analyzing words...")
670
- analysis_result = self.word_analyzer.analyze_words(reference_text, phoneme_representation)
671
-
672
- # Step 3: Calculate overall score
673
- phoneme_comparisons = analysis_result["phoneme_differences"]
674
- overall_score = self._calculate_overall_score(phoneme_comparisons)
675
-
676
- # Step 4: Generate feedback
677
- print("Step 3: Generating feedback...")
678
- feedback = self.feedback_generator.generate_feedback(
679
- overall_score, analysis_result["wrong_words"], phoneme_comparisons
680
- )
681
-
682
- result = {
683
- "transcript": character_transcript, # What user actually said
684
- "transcript_phonemes": phoneme_representation,
685
- "user_phonemes": phoneme_representation, # Alias for UI clarity
686
- "character_transcript": character_transcript,
687
- "overall_score": overall_score,
688
- "word_highlights": analysis_result["word_highlights"],
689
- "phoneme_differences": phoneme_comparisons,
690
- "wrong_words": analysis_result["wrong_words"],
691
- "feedback": feedback,
692
- "processing_info": {
693
- "model_used": f"Wav2Vec2-Character ({self.asr.model_name})",
694
- "character_based": True,
695
- "language_model_correction": False,
696
- "raw_output": True
697
- }
698
- }
699
-
700
- print("Assessment completed successfully")
701
- return result
702
-
703
- def _calculate_overall_score(self, phoneme_comparisons: List[Dict]) -> float:
704
- """Calculate overall pronunciation score"""
705
- if not phoneme_comparisons:
706
- return 0.0
707
-
708
- total_score = sum(comparison["score"] for comparison in phoneme_comparisons)
709
- return total_score / len(phoneme_comparisons)
710
-
711
- # =============================================================================
712
- # API ENDPOINT
713
- # =============================================================================
714
 
715
- # Initialize assessor
716
  assessor = SimplePronunciationAssessor()
717
 
718
- def convert_numpy_types(obj):
719
- """Convert numpy types to Python native types"""
720
- if isinstance(obj, np.integer):
721
- return int(obj)
722
- elif isinstance(obj, np.floating):
723
- return float(obj)
724
- elif isinstance(obj, np.ndarray):
725
- return obj.tolist()
726
- elif isinstance(obj, dict):
727
- return {key: convert_numpy_types(value) for key, value in obj.items()}
728
- elif isinstance(obj, list):
729
- return [convert_numpy_types(item) for item in obj]
730
- else:
731
- return obj
732
 
733
  @router.post("/assess", response_model=PronunciationAssessmentResult)
734
  async def assess_pronunciation(
735
  audio: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
736
- reference_text: str = Form(..., description="Reference text to pronounce")
 
 
 
 
737
  ):
738
  """
739
- Pronunciation Assessment API using Wav2Vec2 Character-level Model
740
-
741
  Key Features:
742
- - Uses facebook/wav2vec2-large-960h-lv60-self for character transcription
743
- - NO language model correction (shows actual pronunciation errors)
 
744
  - Character-level accuracy converted to phoneme representation
745
  - Vietnamese-optimized feedback and tips
746
-
747
- Input: Audio file + Reference text
748
  Output: Word highlights + Phoneme differences + Wrong words
749
  """
750
-
751
  import time
 
752
  start_time = time.time()
753
-
 
 
 
 
 
 
754
  # Validate inputs
755
  if not reference_text.strip():
756
  raise HTTPException(status_code=400, detail="Reference text cannot be empty")
757
-
758
  if len(reference_text) > 500:
759
- raise HTTPException(status_code=400, detail="Reference text too long (max 500 characters)")
760
-
 
 
761
  # Check for valid English characters
762
  if not re.match(r"^[a-zA-Z\s\'\-\.!?,;:]+$", reference_text):
763
  raise HTTPException(
764
  status_code=400,
765
- detail="Text must contain only English letters, spaces, and basic punctuation"
766
  )
767
-
768
  try:
769
  # Save uploaded file temporarily
770
  file_extension = ".wav"
771
  if audio.filename and "." in audio.filename:
772
  file_extension = f".{audio.filename.split('.')[-1]}"
773
-
774
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
 
 
775
  content = await audio.read()
776
  tmp_file.write(content)
777
  tmp_file.flush()
778
-
779
- print(f"Processing audio file: {tmp_file.name}")
780
-
781
- # Run assessment using Wav2Vec2 Character model
782
- result = assessor.assess_pronunciation(tmp_file.name, reference_text)
783
-
784
-
785
  # Add processing time
786
  processing_time = time.time() - start_time
787
  result["processing_info"]["processing_time"] = processing_time
788
-
789
  # Convert numpy types for JSON serialization
790
  final_result = convert_numpy_types(result)
791
-
792
- print(f"Assessment completed in {processing_time:.2f} seconds")
793
-
 
 
794
  return PronunciationAssessmentResult(**final_result)
795
-
796
  except Exception as e:
797
- print(f"Assessment error: {str(e)}")
798
  import traceback
 
799
  traceback.print_exc()
800
  raise HTTPException(status_code=500, detail=f"Assessment failed: {str(e)}")
801
 
 
802
  # =============================================================================
803
  # UTILITY ENDPOINTS
804
  # =============================================================================
805
 
 
806
  @router.get("/phonemes/{word}")
807
  async def get_word_phonemes(word: str):
808
  """Get phoneme breakdown for a specific word"""
809
  try:
810
  g2p = SimpleG2P()
811
  phoneme_data = g2p.text_to_phonemes(word)[0]
812
-
813
  # Add difficulty analysis for Vietnamese speakers
814
  difficulty_scores = []
815
  comparator = PhonemeComparator()
816
-
817
  for phoneme in phoneme_data["phonemes"]:
818
  difficulty = comparator.difficulty_map.get(phoneme, 0.3)
819
  difficulty_scores.append(difficulty)
820
-
821
  avg_difficulty = float(np.mean(difficulty_scores)) if difficulty_scores else 0.3
822
-
823
  return {
824
  "word": word,
825
  "phonemes": phoneme_data["phonemes"],
826
  "phoneme_string": phoneme_data["phoneme_string"],
827
  "ipa": phoneme_data["ipa"],
828
  "difficulty_score": avg_difficulty,
829
- "difficulty_level": "hard" if avg_difficulty > 0.6 else "medium" if avg_difficulty > 0.4 else "easy",
 
 
 
 
830
  "challenging_phonemes": [
831
  {
832
  "phoneme": p,
833
  "difficulty": comparator.difficulty_map.get(p, 0.3),
834
- "vietnamese_tip": get_vietnamese_tip(p)
835
  }
836
  for p in phoneme_data["phonemes"]
837
  if comparator.difficulty_map.get(p, 0.3) > 0.6
838
- ]
839
- }
840
-
841
- except Exception as e:
842
- raise HTTPException(status_code=500, detail=f"Word analysis error: {str(e)}")
843
-
844
- @router.get("/health")
845
- async def health_check():
846
- """Health check endpoint"""
847
- try:
848
- model_info = {
849
- "status": "healthy",
850
- "model": assessor.asr.model_name,
851
- "character_based": True,
852
- "language_model_correction": False,
853
- "vietnamese_optimized": True
854
- }
855
- return model_info
856
- except Exception as e:
857
- return {
858
- "status": "error",
859
- "error": str(e)
860
  }
861
 
862
- @router.get("/test-model")
863
- async def test_model():
864
- """Test if Wav2Vec2 model is working"""
865
- try:
866
- # Test model info
867
- test_result = {
868
- "model_loaded": True,
869
- "model_name": assessor.asr.model_name,
870
- "processor_ready": True,
871
- "sample_rate": assessor.asr.sample_rate,
872
- "sample_characters": "this is a test",
873
- "sample_phonemes": "ðɪs ɪz ə tɛst"
874
- }
875
- return test_result
876
  except Exception as e:
877
- return {
878
- "model_loaded": False,
879
- "error": str(e)
880
- }
881
 
882
- # =============================================================================
883
- # HELPER FUNCTIONS
884
- # =============================================================================
885
 
886
  def get_vietnamese_tip(phoneme: str) -> str:
887
  """Get Vietnamese pronunciation tip for a phoneme"""
@@ -889,10 +176,10 @@ def get_vietnamese_tip(phoneme: str) -> str:
889
  "θ": "Đặt lưỡi giữa răng, thổi nhẹ",
890
  "ð": "Giống θ nhưng rung dây thanh âm",
891
  "v": "Môi dưới chạm răng trên",
892
- "r": "Cuộn lưỡi, không chạm vòm miệng",
893
  "l": "Lưỡi chạm vòm miệng sau răng",
894
  "z": "Như 's' nhưng rung dây thanh",
895
  "ʒ": "Như 'ʃ' nhưng rung dây thanh",
896
- "w": "Tròn môi như 'u'"
897
  }
898
  return tips.get(phoneme, f"Luyện âm {phoneme}")
 
1
+ from fastapi import UploadFile, File, Form, HTTPException, APIRouter
 
 
 
 
 
2
  from pydantic import BaseModel
3
+ from typing import List, Dict
4
  import tempfile
 
5
  import numpy as np
 
 
 
 
6
  import re
 
7
  import warnings
8
+ from loguru import logger
9
+ from src.apis.controllers.speaking_controller import (
10
+ SimpleG2P,
11
+ PhonemeComparator,
12
+ SimplePronunciationAssessor,
13
+ convert_numpy_types,
14
+ )
15
 
16
  warnings.filterwarnings("ignore")
17
 
18
  router = APIRouter(prefix="/pronunciation", tags=["Pronunciation"])
19
 
20
+
21
  class PronunciationAssessmentResult(BaseModel):
22
  transcript: str # What the user actually said (character transcript)
23
  transcript_phonemes: str # User's phonemes
 
30
  feedback: List[str]
31
  processing_info: Dict
32
 
33
 
 
34
  assessor = SimplePronunciationAssessor()
35
 
36
 
37
  @router.post("/assess", response_model=PronunciationAssessmentResult)
38
  async def assess_pronunciation(
39
  audio: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
40
+ reference_text: str = Form(..., description="Reference text to pronounce"),
41
+ mode: str = Form(
42
+ "normal",
43
+ description="Assessment mode: 'normal' (Whisper) or 'advanced' (Wav2Vec2)",
44
+ ),
45
  ):
46
  """
47
+ Pronunciation Assessment API with mode selection
48
+
49
  Key Features:
50
+ - Normal mode: Uses Whisper for more accurate transcription with a language model
51
+ - Advanced mode: Uses facebook/wav2vec2-large-960h-lv60-self for character transcription
52
+ - NO language model correction in advanced mode (shows actual pronunciation errors)
53
  - Character-level accuracy converted to phoneme representation
54
  - Vietnamese-optimized feedback and tips
55
+
56
+ Input: Audio file + Reference text + Mode
57
  Output: Word highlights + Phoneme differences + Wrong words
58
  """
59
+
60
  import time
61
+
62
  start_time = time.time()
63
+
64
+ # Validate mode
65
+ if mode not in ["normal", "advanced"]:
66
+ raise HTTPException(
67
+ status_code=400, detail="Mode must be 'normal' or 'advanced'"
68
+ )
69
+
70
  # Validate inputs
71
  if not reference_text.strip():
72
  raise HTTPException(status_code=400, detail="Reference text cannot be empty")
73
+
74
  if len(reference_text) > 500:
75
+ raise HTTPException(
76
+ status_code=400, detail="Reference text too long (max 500 characters)"
77
+ )
78
+
79
  # Check for valid English characters
80
  if not re.match(r"^[a-zA-Z\s\'\-\.!?,;:]+$", reference_text):
81
  raise HTTPException(
82
  status_code=400,
83
+ detail="Text must contain only English letters, spaces, and basic punctuation",
84
  )
85
+
86
  try:
87
  # Save uploaded file temporarily
88
  file_extension = ".wav"
89
  if audio.filename and "." in audio.filename:
90
  file_extension = f".{audio.filename.split('.')[-1]}"
91
+
92
+ with tempfile.NamedTemporaryFile(
93
+ delete=False, suffix=file_extension
94
+ ) as tmp_file:
95
  content = await audio.read()
96
  tmp_file.write(content)
97
  tmp_file.flush()
98
+
99
+ logger.info(f"Processing audio file: {tmp_file.name} with mode: {mode}")
100
+
101
+ # Run assessment using selected mode
102
+ result = assessor.assess_pronunciation(tmp_file.name, reference_text, mode)
103
+
 
104
  # Add processing time
105
  processing_time = time.time() - start_time
106
  result["processing_info"]["processing_time"] = processing_time
107
+
108
  # Convert numpy types for JSON serialization
109
  final_result = convert_numpy_types(result)
110
+
111
+ logger.info(
112
+ f"Assessment completed in {processing_time:.2f} seconds using {mode} mode"
113
+ )
114
+
115
  return PronunciationAssessmentResult(**final_result)
116
+
117
  except Exception as e:
118
+ logger.error(f"Assessment error: {str(e)}")
119
  import traceback
120
+
121
  traceback.print_exc()
122
  raise HTTPException(status_code=500, detail=f"Assessment failed: {str(e)}")
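For reference, a minimal client sketch for the updated endpoint. This is not part of the commit: the host, port and file name are placeholders, and the path assumes the router is mounted at the application root.

```python
# Hypothetical client for POST /pronunciation/assess (host/port/file are placeholders).
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/pronunciation/assess",
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"reference_text": "this is a test", "mode": "advanced"},  # or "normal"
    )
resp.raise_for_status()
payload = resp.json()
print(payload["transcript"])
print(payload["feedback"])
```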
123
 
124
+
125
  # =============================================================================
126
  # UTILITY ENDPOINTS
127
  # =============================================================================
128
 
129
+
130
  @router.get("/phonemes/{word}")
131
  async def get_word_phonemes(word: str):
132
  """Get phoneme breakdown for a specific word"""
133
  try:
134
  g2p = SimpleG2P()
135
  phoneme_data = g2p.text_to_phonemes(word)[0]
136
+
137
  # Add difficulty analysis for Vietnamese speakers
138
  difficulty_scores = []
139
  comparator = PhonemeComparator()
140
+
141
  for phoneme in phoneme_data["phonemes"]:
142
  difficulty = comparator.difficulty_map.get(phoneme, 0.3)
143
  difficulty_scores.append(difficulty)
144
+
145
  avg_difficulty = float(np.mean(difficulty_scores)) if difficulty_scores else 0.3
146
+
147
  return {
148
  "word": word,
149
  "phonemes": phoneme_data["phonemes"],
150
  "phoneme_string": phoneme_data["phoneme_string"],
151
  "ipa": phoneme_data["ipa"],
152
  "difficulty_score": avg_difficulty,
153
+ "difficulty_level": (
154
+ "hard"
155
+ if avg_difficulty > 0.6
156
+ else "medium" if avg_difficulty > 0.4 else "easy"
157
+ ),
158
  "challenging_phonemes": [
159
  {
160
  "phoneme": p,
161
  "difficulty": comparator.difficulty_map.get(p, 0.3),
162
+ "vietnamese_tip": get_vietnamese_tip(p),
163
  }
164
  for p in phoneme_data["phonemes"]
165
  if comparator.difficulty_map.get(p, 0.3) > 0.6
166
+ ],
 
167
  }
168
 
169
  except Exception as e:
170
+ raise HTTPException(status_code=500, detail=f"Word analysis error: {str(e)}")
 
 
 
171
 
 
 
 
172
 
173
  def get_vietnamese_tip(phoneme: str) -> str:
174
  """Get Vietnamese pronunciation tip for a phoneme"""
 
176
  "θ": "Đặt lưỡi giữa răng, thổi nhẹ",
177
  "ð": "Giống θ nhưng rung dây thanh âm",
178
  "v": "Môi dưới chạm răng trên",
179
+ "r": "Cuộn lưỡi, không chạm vòm miệng",
180
  "l": "Lưỡi chạm vòm miệng sau răng",
181
  "z": "Như 's' nhưng rung dây thanh",
182
  "ʒ": "Như 'ʃ' nhưng rung dây thanh",
183
+ "w": "Tròn môi như 'u'",
184
  }
185
  return tips.get(phoneme, f"Luyện âm {phoneme}")
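The word-level helper endpoint can be exercised the same way; a small sketch, again with an assumed host and mount point:

```python
# Hypothetical request to GET /pronunciation/phonemes/{word}; host/port are placeholders.
import requests

info = requests.get("http://localhost:8000/pronunciation/phonemes/three").json()
print(info["ipa"], info["difficulty_level"])
print(info["challenging_phonemes"])  # phonemes with difficulty > 0.6 plus Vietnamese tips
```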
src/model_convert/wav2vec2onnx.py ADDED
@@ -0,0 +1,373 @@
1
+ import torch
2
+ import onnx
3
+ import onnxruntime
4
+ import numpy as np
5
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
6
+ from typing import Dict, Tuple
7
+ import librosa
8
+ import os
9
+
10
+ class Wav2Vec2ONNXConverter:
11
+ """Convert Wav2Vec2 model to ONNX format"""
12
+
13
+ def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
14
+ """Initialize the converter with the specified model"""
15
+ print(f"Loading Wav2Vec2 model: {model_name}")
16
+ self.model_name = model_name
17
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
18
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
19
+
20
+ # Disable flash attention and scaled_dot_product_attention for ONNX compatibility
21
+ if hasattr(self.model.config, 'use_flash_attention_2'):
22
+ self.model.config.use_flash_attention_2 = False
23
+
24
+ # Force model to use standard attention
25
+ if hasattr(self.model, 'wav2vec2') and hasattr(self.model.wav2vec2, 'encoder'):
26
+ for layer in self.model.wav2vec2.encoder.layers:
27
+ if hasattr(layer.attention, 'attention_dropout'):
28
+ # Ensure standard attention is used
29
+ layer.attention.attention_dropout = torch.nn.Dropout(layer.attention.attention_dropout.p)
30
+
31
+ self.model.eval()
32
+ self.sample_rate = 16000
33
+ print("Model loaded successfully")
34
+
35
+ def convert_to_onnx(self,
36
+ onnx_path: str = "wav2vec2_model.onnx",
37
+ input_length: int = 160000, # 10 seconds at 16kHz
38
+ opset_version: int = 14) -> str:
39
+ """
40
+ Convert the Wav2Vec2 model to ONNX format
41
+
42
+ Args:
43
+ onnx_path: Path to save the ONNX model
44
+ input_length: Length of input audio (samples)
45
+ opset_version: ONNX opset version
46
+
47
+ Returns:
48
+ Path to the saved ONNX model
49
+ """
50
+ print(f"Converting model to ONNX format...")
51
+
52
+ # Create dummy input
53
+ dummy_input = torch.randn(1, input_length, dtype=torch.float32)
54
+
55
+ # Input names and dynamic axes
56
+ input_names = ["input_values"]
57
+ output_names = ["logits"]
58
+
59
+ # Dynamic axes for variable length input
60
+ dynamic_axes = {
61
+ "input_values": {0: "batch_size", 1: "sequence_length"},
62
+ "logits": {0: "batch_size", 1: "sequence_length"}
63
+ }
64
+
65
+ try:
66
+ # Disable torch optimizations that may cause ONNX issues
67
+ with torch.no_grad():
68
+ # Set model to evaluation mode and disable dropout
69
+ self.model.eval()
70
+ for module in self.model.modules():
71
+ if isinstance(module, torch.nn.Dropout):
72
+ module.p = 0.0
73
+
74
+ # Export to ONNX
75
+ torch.onnx.export(
76
+ self.model,
77
+ dummy_input,
78
+ onnx_path,
79
+ input_names=input_names,
80
+ output_names=output_names,
81
+ dynamic_axes=dynamic_axes,
82
+ opset_version=opset_version,
83
+ do_constant_folding=True,
84
+ verbose=False,
85
+ export_params=True,
86
+ training=torch.onnx.TrainingMode.EVAL,
87
+ operator_export_type=torch.onnx.OperatorExportTypes.ONNX
88
+ )
89
+
90
+ print(f"Model successfully exported to: {onnx_path}")
91
+
92
+ # Verify the exported model
93
+ self._verify_onnx_model(onnx_path, dummy_input)
94
+
95
+ return onnx_path
96
+
97
+ except Exception as e:
98
+ print(f"Error during ONNX conversion: {e}")
99
+ raise
100
+
101
+ def _verify_onnx_model(self, onnx_path: str, test_input: torch.Tensor):
102
+ """Verify the exported ONNX model"""
103
+ print("Verifying ONNX model...")
104
+
105
+ try:
106
+ # Load and check ONNX model
107
+ onnx_model = onnx.load(onnx_path)
108
+ onnx.checker.check_model(onnx_model)
109
+ print("✓ ONNX model structure is valid")
110
+
111
+ # Test inference with ONNX Runtime
112
+ ort_session = onnxruntime.InferenceSession(onnx_path)
113
+
114
+ # Get model input/output info
115
+ input_name = ort_session.get_inputs()[0].name
116
+ output_name = ort_session.get_outputs()[0].name
117
+
118
+ print(f"✓ Input name: {input_name}")
119
+ print(f"✓ Output name: {output_name}")
120
+
121
+ # Run inference
122
+ ort_inputs = {input_name: test_input.numpy()}
123
+ ort_outputs = ort_session.run([output_name], ort_inputs)
124
+
125
+ # Compare with original PyTorch model
126
+ with torch.no_grad():
127
+ torch_output = self.model(test_input)
128
+ torch_logits = torch_output.logits
129
+
130
+ # Check output similarity
131
+ onnx_logits = ort_outputs[0]
132
+ max_diff = np.max(np.abs(torch_logits.numpy() - onnx_logits))
133
+
134
+ print(f"✓ Maximum difference between PyTorch and ONNX: {max_diff:.6f}")
135
+
136
+ if max_diff < 1e-4:
137
+ print("✓ ONNX model verification successful!")
138
+ else:
139
+ print("⚠ Warning: Large difference detected between models")
140
+
141
+ except Exception as e:
142
+ print(f"Error during verification: {e}")
143
+ raise
144
+
145
+ class Wav2Vec2ONNXInference:
146
+ """ONNX inference class for Wav2Vec2"""
147
+
148
+ def __init__(self, onnx_path: str, processor_name: str = "facebook/wav2vec2-base-960h"):
149
+ """Initialize ONNX inference"""
150
+ print(f"Loading ONNX model from: {onnx_path}")
151
+
152
+ # Load processor for tokenization
153
+ self.processor = Wav2Vec2Processor.from_pretrained(processor_name)
154
+
155
+ # Create ONNX Runtime session
156
+ self.session = onnxruntime.InferenceSession(onnx_path)
157
+ self.input_name = self.session.get_inputs()[0].name
158
+ self.output_name = self.session.get_outputs()[0].name
159
+ self.sample_rate = 16000
160
+
161
+ print("ONNX model loaded successfully")
162
+
163
+ def transcribe(self, audio_path: str) -> Dict:
164
+ """Transcribe audio using ONNX model"""
165
+ try:
166
+ # Load audio
167
+ speech, sr = librosa.load(audio_path, sr=self.sample_rate)
168
+
169
+ # Prepare input
170
+ input_values = self.processor(
171
+ speech,
172
+ sampling_rate=self.sample_rate,
173
+ return_tensors="np"
174
+ ).input_values
175
+
176
+ # Run ONNX inference
177
+ ort_inputs = {self.input_name: input_values}
178
+ ort_outputs = self.session.run([self.output_name], ort_inputs)
179
+ logits = ort_outputs[0]
180
+
181
+ # Decode predictions
182
+ predicted_ids = np.argmax(logits, axis=-1)
183
+ transcription = self.processor.batch_decode(predicted_ids)[0]
184
+
185
+ # Calculate confidence scores
186
+ confidence_scores = np.max(self._softmax(logits), axis=-1)[0]
187
+
188
+ return {
189
+ "transcription": transcription,
190
+ "confidence_scores": confidence_scores[:100].tolist(), # Limit for JSON
191
+ "predicted_ids": predicted_ids[0].tolist()
192
+ }
193
+
194
+ except Exception as e:
195
+ print(f"Transcription error: {e}")
196
+ return {
197
+ "transcription": "",
198
+ "confidence_scores": [],
199
+ "predicted_ids": []
200
+ }
201
+
202
+ def _softmax(self, x):
203
+ """Apply softmax to logits"""
204
+ exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
205
+ return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
206
+
207
+ # Example usage and testing
208
+ def main():
209
+ """Example usage of the converter"""
210
+
211
+ # Method 1: Try standard conversion
212
+ try:
213
+ print("Method 1: Standard conversion...")
214
+ converter = Wav2Vec2ONNXConverter("facebook/wav2vec2-base-960h")
215
+ onnx_path = converter.convert_to_onnx(
216
+ onnx_path="wav2vec2_asr.onnx",
217
+ input_length=160000, # 10 seconds
218
+ opset_version=14 # Updated to version 14 for compatibility
219
+ )
220
+ print("✓ Standard conversion successful!")
221
+
222
+ except Exception as e:
223
+ print(f"✗ Standard conversion failed: {e}")
224
+ print("\nMethod 2: Trying fallback approach...")
225
+
226
+ try:
227
+ # Method 2: Use compatible model creation
228
+ model, processor = create_compatible_model("facebook/wav2vec2-base-960h")
229
+ onnx_path = export_with_fallback(
230
+ model,
231
+ processor,
232
+ "wav2vec2_asr_fallback.onnx",
233
+ input_length=160000
234
+ )
235
+ print("✓ Fallback conversion successful!")
236
+
237
+ except Exception as e2:
238
+ print(f"✗ All conversion methods failed: {e2}")
239
+ return
240
+
241
+ # Test ONNX inference
242
+ print("\nTesting ONNX inference...")
243
+ try:
244
+ onnx_inference = Wav2Vec2ONNXInference(onnx_path)
245
+ print("✓ ONNX model loaded successfully for inference")
246
+
247
+ # Create a test audio file (or use your own)
248
+ # result = onnx_inference.transcribe("test_audio.wav")
249
+ # print("Transcription:", result["transcription"])
250
+
251
+ except Exception as e:
252
+ print(f"✗ ONNX inference test failed: {e}")
253
+
254
+ print("Conversion process completed!")
255
+
256
+ # Additional utility functions
257
+ def create_compatible_model(model_name: str = "facebook/wav2vec2-base-960h"):
258
+ """Create a Wav2Vec2 model compatible with ONNX export"""
259
+ from transformers import Wav2Vec2Config
260
+
261
+ # Load config and modify for ONNX compatibility
262
+ config = Wav2Vec2Config.from_pretrained(model_name)
263
+
264
+ # Disable features that may cause ONNX issues
265
+ if hasattr(config, 'use_flash_attention_2'):
266
+ config.use_flash_attention_2 = False
267
+ if hasattr(config, 'torch_dtype'):
268
+ config.torch_dtype = torch.float32
269
+
270
+ # Load model with modified config
271
+ model = Wav2Vec2ForCTC.from_pretrained(model_name, config=config, torch_dtype=torch.float32)
272
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
273
+
274
+ return model, processor
275
+
276
+ def export_with_fallback(model, processor, onnx_path: str, input_length: int = 160000):
277
+ """Export model with fallback options for different opset versions"""
278
+
279
+ dummy_input = torch.randn(1, input_length, dtype=torch.float32)
280
+ input_names = ["input_values"]
281
+ output_names = ["logits"]
282
+
283
+ dynamic_axes = {
284
+ "input_values": {0: "batch_size", 1: "sequence_length"},
285
+ "logits": {0: "batch_size", 1: "sequence_length"}
286
+ }
287
+
288
+ # Try different opset versions
289
+ opset_versions = [14, 13, 12, 11]
290
+
291
+ for opset_version in opset_versions:
292
+ try:
293
+ print(f"Trying ONNX export with opset version {opset_version}...")
294
+
295
+ with torch.no_grad():
296
+ model.eval()
297
+
298
+ # Disable all dropouts
299
+ for module in model.modules():
300
+ if isinstance(module, torch.nn.Dropout):
301
+ module.p = 0.0
302
+
303
+ torch.onnx.export(
304
+ model,
305
+ dummy_input,
306
+ onnx_path,
307
+ input_names=input_names,
308
+ output_names=output_names,
309
+ dynamic_axes=dynamic_axes,
310
+ opset_version=opset_version,
311
+ do_constant_folding=True,
312
+ verbose=False,
313
+ export_params=True,
314
+ training=torch.onnx.TrainingMode.EVAL
315
+ )
316
+
317
+ print(f"✓ Successfully exported with opset version {opset_version}")
318
+ return onnx_path
319
+
320
+ except Exception as e:
321
+ print(f"✗ Failed with opset {opset_version}: {str(e)[:100]}...")
322
+ continue
323
+
324
+ raise Exception("Failed to export with all attempted opset versions")
325
+ def optimize_onnx_model(onnx_path: str, optimized_path: str = None):
326
+ """Optimize ONNX model for inference"""
327
+ try:
328
+ from onnxruntime.transformers import optimizer
329
+
330
+ if optimized_path is None:
331
+ optimized_path = onnx_path.replace(".onnx", "_optimized.onnx")
332
+
333
+ # Optimize model
334
+ opt_model = optimizer.optimize_model(
335
+ onnx_path,
336
+ model_type="bert", # Similar architecture
337
+ num_heads=12,
338
+ hidden_size=768
339
+ )
340
+
341
+ opt_model.save_model_to_file(optimized_path)
342
+ print(f"Optimized model saved to: {optimized_path}")
343
+
344
+ return optimized_path
345
+
346
+ except ImportError:
347
+ print("ONNX Runtime tools not available for optimization")
348
+ return onnx_path
349
+ except Exception as e:
350
+ print(f"Optimization error: {e}")
351
+ return onnx_path
352
+
353
+ def compare_models(original_converter, onnx_inference, test_audio_path: str):
354
+ """Compare PyTorch and ONNX model outputs"""
355
+ print("Comparing PyTorch vs ONNX outputs...")
356
+
357
+ # PyTorch inference
358
+ torch_result = original_converter.transcribe_to_characters(test_audio_path)
359
+
360
+ # ONNX inference
361
+ onnx_result = onnx_inference.transcribe(test_audio_path)
362
+
363
+ print(f"PyTorch transcription: {torch_result['character_transcript']}")
364
+ print(f"ONNX transcription: {onnx_result['transcription']}")
365
+
366
+ # Compare similarity
367
+ if torch_result['character_transcript'] == onnx_result['transcription']:
368
+ print("✓ Transcriptions match exactly!")
369
+ else:
370
+ print("⚠ Transcriptions differ")
371
+
372
+ if __name__ == "__main__":
373
+ main()
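Beyond optimize_onnx_model above, a common follow-up for CPU deployment is dynamic int8 quantization. This is not part of the commit; a minimal sketch, assuming onnxruntime's quantization tooling is installed and the export produced wav2vec2_asr.onnx (file names are placeholders):

```python
# Optional post-export step (not in this commit): dynamic int8 quantization.
# Weights become int8, activations stay float32; file names are placeholders.
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="wav2vec2_asr.onnx",
    model_output="wav2vec2_asr_int8.onnx",
    weight_type=QuantType.QInt8,
)
# The quantized file loads with the same Wav2Vec2ONNXInference class; re-running
# compare_models afterwards is worthwhile, since quantization can shift transcriptions slightly.
```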
src/utils/helper.py ADDED
@@ -0,0 +1,17 @@
1
+ import numpy as np
2
+
3
+
4
+ def convert_numpy_types(obj):
5
+ """Convert numpy types to Python native types"""
6
+ if isinstance(obj, np.integer):
7
+ return int(obj)
8
+ elif isinstance(obj, np.floating):
9
+ return float(obj)
10
+ elif isinstance(obj, np.ndarray):
11
+ return obj.tolist()
12
+ elif isinstance(obj, dict):
13
+ return {key: convert_numpy_types(value) for key, value in obj.items()}
14
+ elif isinstance(obj, list):
15
+ return [convert_numpy_types(item) for item in obj]
16
+ else:
17
+ return obj
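A quick illustration of why this helper is needed (numpy scalars and arrays are not JSON-serializable by default); the import path assumes the repo layout above:

```python
import json

import numpy as np

from src.utils.helper import convert_numpy_types

result = {"overall_score": np.float32(0.82), "phoneme_scores": np.array([0.9, 0.7])}
# json.dumps(result) raises TypeError; converting first makes it serializable.
print(json.dumps(convert_numpy_types(result)))
```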