fix: update import statement for BaseModel in agent.py and add timing logs in speaking_controller.py
src/agents/evaluation/agent.py CHANGED
@@ -1,5 +1,5 @@
 from langchain_core.prompts import ChatPromptTemplate
-from
+from pydantic import BaseModel, Field
 from src.config.llm import model
 from src.utils.logger import logger
 from .prompt import evaluation_prompt
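The hunk above only shows the import fix; the rest of agent.py is not part of this diff. For context, a pydantic import like this in a LangChain agent module usually backs a structured-output schema. A minimal sketch of that pattern follows; EvaluationResult and its fields are hypothetical names that do not appear in the commit:

# Hypothetical sketch, not part of this commit: the usual reason an evaluation
# agent imports BaseModel/Field is to declare a structured-output schema.
from pydantic import BaseModel, Field


class EvaluationResult(BaseModel):  # hypothetical schema name
    score: float = Field(description="Overall score from 0 to 10")
    feedback: str = Field(description="Short, actionable feedback for the learner")


# Binding the schema to a chat model is the typical LangChain pattern, e.g.:
# chain = evaluation_prompt | model.with_structured_output(EvaluationResult)

Field descriptions become part of what the model sees when the schema is bound, so keeping them short and concrete tends to help.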
src/apis/controllers/speaking_controller.py CHANGED
@@ -17,8 +17,8 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
 from loguru import logger
 import onnxruntime
+import time
 
-warnings.filterwarnings("ignore")
 
 # Download required NLTK data
 try:
@@ -66,7 +66,8 @@ class WhisperASR:
         Returns transcript and confidence score
         """
         try:
-
+
+            start_time = time.time()
             audio, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Process audio
@@ -95,7 +96,7 @@ class WhisperASR:
             # Convert to phoneme representation for comparison
             g2p = SimpleG2P()
             phoneme_representation = g2p.get_reference_phoneme_string(transcript)
-
+            logger.info(f"Whisper transcription time: {time.time() - start_time:.2f}s")
             return {
                 "character_transcript": transcript,
                 "phoneme_representation": phoneme_representation,
@@ -179,49 +180,7 @@ class Wav2Vec2CharacterASRONNX:
 
         except ImportError as e:
             print(f"Error importing Wav2Vec2ONNXConverter: {e}")
-
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-        except Exception as e:
-            print(f"Error creating ONNX model: {e}")
-            # Try fallback method
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-    def _fallback_create_onnx_model(self, onnx_model_path: str, processor_name: str):
-        """Fallback method to create ONNX model using basic torch.onnx.export"""
-        try:
-            print("Using fallback method to create ONNX model...")
-
-            # Load PyTorch model
-            model = Wav2Vec2ForCTC.from_pretrained(processor_name)
-            model.eval()
-
-            # Create dummy input
-            dummy_input = torch.randn(1, 160000, dtype=torch.float32)
-
-            # Export to ONNX
-            with torch.no_grad():
-                torch.onnx.export(
-                    model,
-                    dummy_input,
-                    onnx_model_path,
-                    input_names=["input_values"],
-                    output_names=["logits"],
-                    dynamic_axes={
-                        "input_values": {0: "batch_size", 1: "sequence_length"},
-                        "logits": {0: "batch_size", 1: "sequence_length"},
-                    },
-                    opset_version=14,
-                    do_constant_folding=True,
-                    verbose=False,
-                    export_params=True,
-                )
-
-            print(f"✓ Fallback ONNX model created at: {onnx_model_path}")
-
-        except Exception as e:
-            print(f"Fallback method also failed: {e}")
-            raise Exception(f"Could not create ONNX model: {e}")
+            raise e
 
     def transcribe_to_characters(self, audio_path: str) -> Dict:
         """
@@ -230,6 +189,7 @@ class Wav2Vec2CharacterASRONNX:
         """
         try:
             # Load audio
+            start_time = time.time()
             speech, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Prepare input for ONNX
@@ -261,6 +221,9 @@ class Wav2Vec2CharacterASRONNX:
 
             # Calculate confidence scores
            confidence_scores = self._calculate_confidence_scores(logits)
+            logger.info(
+                f"Wav2Vec2 ONNX transcription time: {time.time() - start_time:.2f}s"
+            )
 
             return {
                 "character_transcript": character_transcript,
@@ -934,6 +897,7 @@ class SimplePronunciationAssessor:
             print("Step 1: Using Whisper transcription...")
             asr_result = self.whisper_asr.transcribe_to_text(audio_path)
             model_info = f"Whisper ({self.whisper_asr.model_name})"
+            print(f"Whisper ASR result: {asr_result}")
 
             character_transcript = asr_result["character_transcript"]
             phoneme_representation = asr_result["phoneme_representation"]
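Both timing additions follow the same pattern: record time.time() at the start of the transcription method and log the elapsed seconds through loguru once decoding and confidence scoring are done. Purely as an illustration (not part of this commit), that pattern can be factored into a small context manager; log_duration is a hypothetical name, and the logger is assumed to be the same loguru instance already imported in the controller:

# Hypothetical helper, not in this commit. Wraps the start_time / logger.info
# pattern that the diff repeats in transcribe_to_text and transcribe_to_characters.
import time
from contextlib import contextmanager

from loguru import logger


@contextmanager
def log_duration(label: str):
    """Log how long the wrapped block took, in seconds."""
    start_time = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{label} time: {time.perf_counter() - start_time:.2f}s")


# Illustrative usage:
# with log_duration("Whisper transcription"):
#     audio, sr = librosa.load(audio_path, sr=16000)
#     ...

The sketch uses time.perf_counter() rather than time.time() because it is monotonic and better suited to short durations; the commit's time.time() is fine for second-level log messages.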