ABAO77 committed
Commit 1a5420f · 1 Parent(s): df380ff

fix: update import statement for BaseModel in agent.py and add timing logs in speaking_controller.py

src/agents/evaluation/agent.py CHANGED

@@ -1,5 +1,5 @@
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field
 from src.config.llm import model
 from src.utils.logger import logger
 from .prompt import evaluation_prompt
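The switch from the deprecated `langchain_core.pydantic_v1` shim to plain `pydantic` matters because recent `langchain-core` releases accept pydantic v2 models directly for structured output. A minimal sketch of the usual wiring — the `EvaluationResult` schema and its fields are hypothetical, since the rest of agent.py is not shown in this diff:

```python
from pydantic import BaseModel, Field

from src.config.llm import model  # the chat model this repo configures


# Hypothetical schema; the real fields in agent.py are not visible in the diff.
class EvaluationResult(BaseModel):
    score: int = Field(description="Overall score from 0 to 10")
    feedback: str = Field(description="Short feedback for the learner")


# with_structured_output() accepts plain pydantic v2 models in recent
# langchain-core, which is why the pydantic_v1 shim import could be dropped.
structured_model = model.with_structured_output(EvaluationResult)
```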
src/apis/controllers/speaking_controller.py CHANGED

@@ -17,8 +17,8 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
 from loguru import logger
 import onnxruntime
+import time
 
-warnings.filterwarnings("ignore")
 
 # Download required NLTK data
 try:
@@ -66,7 +66,8 @@ class WhisperASR:
         Returns transcript and confidence score
         """
         try:
-            # Load audio
+
+            start_time = time.time()
             audio, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Process audio
@@ -95,7 +96,7 @@ class WhisperASR:
             # Convert to phoneme representation for comparison
             g2p = SimpleG2P()
             phoneme_representation = g2p.get_reference_phoneme_string(transcript)
-
+            logger.info(f"Whisper transcription time: {time.time() - start_time:.2f}s")
             return {
                 "character_transcript": transcript,
                 "phoneme_representation": phoneme_representation,
@@ -179,49 +180,7 @@ class Wav2Vec2CharacterASRONNX:
 
         except ImportError as e:
             print(f"Error importing Wav2Vec2ONNXConverter: {e}")
-            # Fallback: use the convert_to_onnx.py directly if wav2vec2onnx.py doesn't work
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-        except Exception as e:
-            print(f"Error creating ONNX model: {e}")
-            # Try fallback method
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-    def _fallback_create_onnx_model(self, onnx_model_path: str, processor_name: str):
-        """Fallback method to create ONNX model using basic torch.onnx.export"""
-        try:
-            print("Using fallback method to create ONNX model...")
-
-            # Load PyTorch model
-            model = Wav2Vec2ForCTC.from_pretrained(processor_name)
-            model.eval()
-
-            # Create dummy input
-            dummy_input = torch.randn(1, 160000, dtype=torch.float32)
-
-            # Export to ONNX
-            with torch.no_grad():
-                torch.onnx.export(
-                    model,
-                    dummy_input,
-                    onnx_model_path,
-                    input_names=["input_values"],
-                    output_names=["logits"],
-                    dynamic_axes={
-                        "input_values": {0: "batch_size", 1: "sequence_length"},
-                        "logits": {0: "batch_size", 1: "sequence_length"},
-                    },
-                    opset_version=14,
-                    do_constant_folding=True,
-                    verbose=False,
-                    export_params=True,
-                )
-
-            print(f"✓ Fallback ONNX model created at: {onnx_model_path}")
-
-        except Exception as e:
-            print(f"Fallback method also failed: {e}")
-            raise Exception(f"Could not create ONNX model: {e}")
+            raise e
 
     def transcribe_to_characters(self, audio_path: str) -> Dict:
         """
@@ -230,6 +189,7 @@ class Wav2Vec2CharacterASRONNX:
         """
         try:
             # Load audio
+            start_time = time.time()
             speech, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Prepare input for ONNX
@@ -261,6 +221,9 @@ class Wav2Vec2CharacterASRONNX:
 
             # Calculate confidence scores
             confidence_scores = self._calculate_confidence_scores(logits)
+            logger.info(
+                f"Wav2Vec2 ONNX transcription time: {time.time() - start_time:.2f}s"
+            )
 
             return {
                 "character_transcript": character_transcript,
@@ -934,6 +897,7 @@ class SimplePronunciationAssessor:
             print("Step 1: Using Whisper transcription...")
             asr_result = self.whisper_asr.transcribe_to_text(audio_path)
             model_info = f"Whisper ({self.whisper_asr.model_name})"
+            print(f"Whisper ASR result: {asr_result}")
 
             character_transcript = asr_result["character_transcript"]
             phoneme_representation = asr_result["phoneme_representation"]