ABAO77 committed
Commit 5d88ac1 · 1 Parent(s): 9537fdb

add deepspeed

inference.py ADDED
@@ -0,0 +1,319 @@
+ import torch
+ from transformers import (
+     Wav2Vec2ForCTC,
+     Wav2Vec2Processor,
+     AutoProcessor,
+     AutoModelForCTC,
+ )
+
+ import deepspeed
+ import librosa
+ import numpy as np
+ from typing import Optional, List, Union
+
+
+ def get_model_name(model_name: Optional[str] = None) -> str:
+     """Helper function to get model name with default fallback"""
+     if model_name is None:
+         return "facebook/wav2vec2-large-robust-ft-libri-960h"
+     return model_name
+
+
+ class Wave2Vec2Inference:
+     def __init__(
+         self,
+         model_name: Optional[str] = None,
+         use_gpu: bool = True,
+         use_deepspeed: bool = True,
+     ):
+         """
+         Initialize Wav2Vec2 model for inference with optional DeepSpeed optimization.
+
+         Args:
+             model_name: HuggingFace model name or None for default
+             use_gpu: Whether to use GPU acceleration
+             use_deepspeed: Whether to use DeepSpeed optimization
+         """
+         # Get the actual model name using helper function
+         self.model_name = get_model_name(model_name)
+         self.use_deepspeed = use_deepspeed
+
+         # Auto-detect device
+         if use_gpu:
+             if torch.backends.mps.is_available():
+                 self.device = "mps"
+             elif torch.cuda.is_available():
+                 self.device = "cuda"
+             else:
+                 self.device = "cpu"
+         else:
+             self.device = "cpu"
+
+         print(f"Using device: {self.device}")
+         print(f"Loading model: {self.model_name}")
+         print(f"DeepSpeed enabled: {self.use_deepspeed}")
+
+         # Check if model is XLSR and use appropriate processor/model
+         is_xlsr = "xlsr" in self.model_name.lower()
+
+         if is_xlsr:
+             print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
+             self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
+             self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
+         else:
+             print("Using AutoProcessor and AutoModelForCTC")
+             self.processor = AutoProcessor.from_pretrained(self.model_name)
+             self.model = AutoModelForCTC.from_pretrained(self.model_name)
+
+         # Initialize DeepSpeed if enabled
+         if self.use_deepspeed:
+             self._init_deepspeed()
+         else:
+             self.model.to(self.device)
+             self.model.eval()
+             self.ds_engine = None
+
+         # Disable gradients for inference
+         torch.set_grad_enabled(False)
+
+     def _init_deepspeed(self):
+         """Initialize DeepSpeed inference engine"""
+         try:
+             # DeepSpeed configuration based on device
+             if self.device == "cuda":
+                 ds_config = {
+                     "tensor_parallel": {"tp_size": 1},
+                     "dtype": torch.float32,
+                     "replace_with_kernel_inject": True,
+                     "enable_cuda_graph": False,
+                 }
+             else:
+                 ds_config = {
+                     "tensor_parallel": {"tp_size": 1},
+                     "dtype": torch.float32,
+                     "replace_with_kernel_inject": False,
+                     "enable_cuda_graph": False,
+                 }
+
+             print("Initializing DeepSpeed inference engine...")
+             self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
+             self.ds_engine.module.to(self.device)
+
+         except Exception as e:
+             print(f"DeepSpeed initialization failed: {e}")
+             print("Falling back to standard PyTorch inference...")
+             self.use_deepspeed = False
+             self.ds_engine = None
+             self.model.to(self.device)
+             self.model.eval()
+
+     def _get_model(self):
+         """Get the appropriate model for inference"""
+         if self.use_deepspeed and self.ds_engine is not None:
+             return self.ds_engine.module
+         return self.model
+
+     def buffer_to_text(
+         self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
+     ) -> str:
+         """
+         Convert audio buffer to text transcription.
+
+         Args:
+             audio_buffer: Audio data as numpy array, tensor, or list
+
+         Returns:
+             str: Transcribed text
+         """
+         if len(audio_buffer) == 0:
+             return ""
+
+         # Convert to tensor
+         if isinstance(audio_buffer, np.ndarray):
+             audio_tensor = torch.from_numpy(audio_buffer).float()
+         elif isinstance(audio_buffer, list):
+             audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
+         else:
+             audio_tensor = audio_buffer.float()
+
+         # Process audio
+         inputs = self.processor(
+             audio_tensor,
+             sampling_rate=16_000,
+             return_tensors="pt",
+             padding=True,
+         )
+
+         # Move to device
+         input_values = inputs.input_values.to(self.device)
+         attention_mask = (
+             inputs.attention_mask.to(self.device)
+             if "attention_mask" in inputs
+             else None
+         )
+
+         # Get the appropriate model
+         model = self._get_model()
+
+         # Inference
+         with torch.no_grad():
+             if attention_mask is not None:
+                 outputs = model(input_values, attention_mask=attention_mask)
+             else:
+                 outputs = model(input_values)
+
+         # Handle different output formats
+         if hasattr(outputs, "logits"):
+             logits = outputs.logits
+         else:
+             logits = outputs
+
+         # Decode
+         predicted_ids = torch.argmax(logits, dim=-1)
+         if self.device != "cpu":
+             predicted_ids = predicted_ids.cpu()
+
+         transcription = self.processor.batch_decode(predicted_ids)[0]
+         return transcription.lower().strip()
+
+     def file_to_text(self, filename: str) -> str:
+         """
+         Transcribe audio file to text.
+
+         Args:
+             filename: Path to audio file
+
+         Returns:
+             str: Transcribed text
+         """
+         try:
+             audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
+             return self.buffer_to_text(audio_input)
+         except Exception as e:
+             print(f"Error loading audio file {filename}: {e}")
+             return ""
+
+     def batch_file_to_text(self, filenames: List[str]) -> List[str]:
+         """
+         Transcribe multiple audio files to text.
+
+         Args:
+             filenames: List of audio file paths
+
+         Returns:
+             List[str]: List of transcribed texts
+         """
+         results = []
+         for i, filename in enumerate(filenames):
+             print(f"Processing file {i+1}/{len(filenames)}: {filename}")
+             transcription = self.file_to_text(filename)
+             results.append(transcription)
+             if transcription:
+                 print(f"Transcription: {transcription}")
+             else:
+                 print("Failed to transcribe")
+         return results
+
+     def transcribe_with_confidence(
+         self, audio_buffer: Union[np.ndarray, torch.Tensor]
+     ) -> tuple:
+         """
+         Transcribe audio and return confidence scores.
+
+         Args:
+             audio_buffer: Audio data
+
+         Returns:
+             tuple: (transcription, confidence_scores)
+         """
+         if len(audio_buffer) == 0:
+             return "", []
+
+         # Convert to tensor
+         if isinstance(audio_buffer, np.ndarray):
+             audio_tensor = torch.from_numpy(audio_buffer).float()
+         else:
+             audio_tensor = audio_buffer.float()
+
+         # Process audio
+         inputs = self.processor(
+             audio_tensor,
+             sampling_rate=16_000,
+             return_tensors="pt",
+             padding=True,
+         )
+
+         input_values = inputs.input_values.to(self.device)
+         attention_mask = (
+             inputs.attention_mask.to(self.device)
+             if "attention_mask" in inputs
+             else None
+         )
+
+         model = self._get_model()
+
+         # Inference
+         with torch.no_grad():
+             if attention_mask is not None:
+                 outputs = model(input_values, attention_mask=attention_mask)
+             else:
+                 outputs = model(input_values)
+
+         if hasattr(outputs, "logits"):
+             logits = outputs.logits
+         else:
+             logits = outputs
+
+         # Get probabilities and confidence scores
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         predicted_ids = torch.argmax(logits, dim=-1)
+
+         # Calculate confidence as max probability for each prediction
+         max_probs = torch.max(probs, dim=-1)[0]
+         confidence_scores = max_probs.cpu().numpy().tolist()
+
+         if self.device != "cpu":
+             predicted_ids = predicted_ids.cpu()
+
+         transcription = self.processor.batch_decode(predicted_ids)[0]
+         return transcription.lower().strip(), confidence_scores
+
+     def cleanup(self):
+         """Clean up resources"""
+         if hasattr(self, "ds_engine") and self.ds_engine is not None:
+             del self.ds_engine
+         if hasattr(self, "model"):
+             del self.model
+         if hasattr(self, "processor"):
+             del self.processor
+         torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+     def __del__(self):
+         """Destructor to clean up resources"""
+         self.cleanup()
+
+
+ # Example usage
+ if __name__ == "__main__":
+     # Initialize with DeepSpeed
+     asr = Wave2Vec2Inference(
+         model_name="facebook/wav2vec2-large-robust-ft-libri-960h",
+         use_gpu=False,
+         use_deepspeed=False,
+     )
+
+     # Single file transcription
+     result = asr.file_to_text("./test_audio/hello_how_are_you_today.wav")
+     print(f"Transcription: {result}")
+
+     # # Batch processing
+     # files = ["audio1.wav", "audio2.wav", "audio3.wav"]
+     # batch_results = asr.batch_file_to_text(files)
+
+     # # Transcription with confidence scores
+     # audio_data, _ = librosa.load("path/to/audio.wav", sr=16000)
+     # transcription, confidence = asr.transcribe_with_confidence(audio_data)
+     # print(f"Transcription: {transcription}")
+     # print(f"Average confidence: {np.mean(confidence):.3f}")
+
+     # Cleanup
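
A minimal usage sketch for the class added above, assuming DeepSpeed is installed and a CUDA GPU is available; the audio path "sample.wav" and the "from inference import ..." module path are illustrative placeholders, not part of the commit:

import numpy as np
import librosa
from inference import Wave2Vec2Inference  # assumes inference.py is importable from the working directory

# Defaults to facebook/wav2vec2-large-robust-ft-libri-960h; if DeepSpeed
# initialization fails, the class falls back to plain PyTorch on its own.
asr = Wave2Vec2Inference(use_gpu=True, use_deepspeed=True)

print(asr.file_to_text("sample.wav"))  # placeholder 16 kHz-compatible audio file

audio, _ = librosa.load("sample.wav", sr=16000)
text, confidence = asr.transcribe_with_confidence(audio)
print(f"{text} (mean confidence: {np.mean(confidence):.3f})")

asr.cleanup()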
requirements.txt CHANGED
@@ -23,4 +23,5 @@ onnx
 transformers
 torch
 optimum[onnxruntime]
- Levenshtein
+ Levenshtein
+ deepspeed
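
The only new dependency is deepspeed, which the code above uses through deepspeed.init_inference. A standalone sketch of that call with the same config keys as _init_deepspeed, written under the assumption that kernel injection stays disabled (as in the non-CUDA branch); the checkpoint and dummy input are illustrative:

import torch
import deepspeed
from transformers import AutoModelForCTC

model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")  # illustrative checkpoint
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 1},
    dtype=torch.float32,
    replace_with_kernel_inject=False,  # mirrors the fallback branch above
)
with torch.no_grad():
    logits = engine.module(torch.zeros(1, 16000)).logits  # one second of silence at 16 kHz
print(logits.shape)  # (batch, frames, vocab_size)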
src/AI_Models/wave2vec_inference.py CHANGED
@@ -1,63 +1,416 @@
1
- import torch
2
- from transformers import (
3
- AutoModelForCTC,
4
- AutoProcessor,
5
- Wav2Vec2Processor,
6
- Wav2Vec2ForCTC,
7
- )
8
- import onnxruntime as rt
9
- import numpy as np
10
- import librosa
11
- import warnings
12
- import os
 
13
 
14
- warnings.filterwarnings("ignore")
15
 
16
- # Available Wave2Vec2 models
17
- WAVE2VEC2_MODELS = {
18
- "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
19
- "multilingual": "facebook/wav2vec2-large-xlsr-53",
20
- "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
21
- "base_english": "facebook/wav2vec2-base-960h",
22
- "large_english": "facebook/wav2vec2-large-960h",
23
- "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
24
- "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53"
25
- }
26
 
27
- # Default model
28
- DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
 
29
 
 
 
30
 
31
- def get_available_models():
32
- """Return dictionary of available Wave2Vec2 models"""
33
- return WAVE2VEC2_MODELS.copy()
34
 
 
 
35
 
36
- def get_model_name(model_key=None):
37
- """
38
- Get model name from key or return default
39
-
40
- Args:
41
- model_key: Key from WAVE2VEC2_MODELS or full model name
42
-
43
- Returns:
44
- str: Full model name
45
- """
46
- if model_key is None:
47
- return DEFAULT_MODEL
48
-
49
- if model_key in WAVE2VEC2_MODELS:
50
- return WAVE2VEC2_MODELS[model_key]
51
-
52
- # If it's already a full model name, return as is
53
- return model_key
54
 
55
 
56
  class Wave2Vec2Inference:
57
- def __init__(self, model_name=None, use_gpu=True):
58
  # Get the actual model name using helper function
59
  self.model_name = get_model_name(model_name)
60
-
 
61
  # Auto-detect device
62
  if use_gpu:
63
  if torch.backends.mps.is_available():
@@ -71,10 +424,11 @@ class Wave2Vec2Inference:
71
 
72
  print(f"Using device: {self.device}")
73
  print(f"Loading model: {self.model_name}")
 
74
 
75
  # Check if model is XLSR and use appropriate processor/model
76
  is_xlsr = "xlsr" in self.model_name.lower()
77
-
78
  if is_xlsr:
79
  print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
80
  self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
@@ -83,22 +437,77 @@ class Wave2Vec2Inference:
83
  print("Using AutoProcessor and AutoModelForCTC")
84
  self.processor = AutoProcessor.from_pretrained(self.model_name)
85
  self.model = AutoModelForCTC.from_pretrained(self.model_name)
86
-
87
- self.model.to(self.device)
88
- self.model.eval()
89
 
90
  # Disable gradients for inference
91
  torch.set_grad_enabled(False)
92
 
93
- def buffer_to_text(self, audio_buffer):
94
  if len(audio_buffer) == 0:
95
  return ""
96
 
97
  # Convert to tensor
98
  if isinstance(audio_buffer, np.ndarray):
99
  audio_tensor = torch.from_numpy(audio_buffer).float()
100
- else:
101
  audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
 
 
102
 
103
  # Process audio
104
  inputs = self.processor(
@@ -116,12 +525,21 @@ class Wave2Vec2Inference:
116
  else None
117
  )
118
 
 
 
 
119
  # Inference
120
  with torch.no_grad():
121
  if attention_mask is not None:
122
- logits = self.model(input_values, attention_mask=attention_mask).logits
123
  else:
124
- logits = self.model(input_values).logits
125
 
126
  # Decode
127
  predicted_ids = torch.argmax(logits, dim=-1)
@@ -131,7 +549,16 @@ class Wave2Vec2Inference:
131
  transcription = self.processor.batch_decode(predicted_ids)[0]
132
  return transcription.lower().strip()
133
 
134
- def file_to_text(self, filename):
135
  try:
136
  audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
137
  return self.buffer_to_text(audio_input)
@@ -139,232 +566,101 @@ class Wave2Vec2Inference:
139
  print(f"Error loading audio file {filename}: {e}")
140
  return ""
141
 
142
 
143
- class Wave2Vec2ONNXInference:
144
- def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
145
- # Get the actual model name using helper function
146
- self.model_name = get_model_name(model_name)
147
- print(f"Loading ONNX model: {self.model_name}")
148
-
149
- # Always use Wav2Vec2Processor for ONNX (works for all models)
150
- self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
151
-
152
- # Setup ONNX Runtime
153
- options = rt.SessionOptions()
154
- options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
155
-
156
- # Choose providers based on GPU availability
157
- providers = []
158
- if use_gpu and rt.get_available_providers():
159
- if "CUDAExecutionProvider" in rt.get_available_providers():
160
- providers.append("CUDAExecutionProvider")
161
- providers.append("CPUExecutionProvider")
162
 
163
- self.model = rt.InferenceSession(onnx_path, options, providers=providers)
164
- self.input_name = self.model.get_inputs()[0].name
165
- print(f"ONNX model loaded with providers: {self.model.get_providers()}")
166
 
167
- def buffer_to_text(self, audio_buffer):
 
 
168
  if len(audio_buffer) == 0:
169
- return ""
170
 
171
  # Convert to tensor
172
  if isinstance(audio_buffer, np.ndarray):
173
  audio_tensor = torch.from_numpy(audio_buffer).float()
174
  else:
175
- audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
176
 
177
  # Process audio
178
  inputs = self.processor(
179
  audio_tensor,
180
  sampling_rate=16_000,
181
- return_tensors="np",
182
  padding=True,
183
  )
184
 
185
- # ONNX inference
186
- input_values = inputs.input_values.astype(np.float32)
187
- onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
188
 
189
- # Decode
190
- prediction = np.argmax(onnx_outputs, axis=-1)
191
- transcription = self.processor.decode(prediction.squeeze().tolist())
192
- return transcription.lower().strip()
 
 
193
 
194
- def file_to_text(self, filename):
195
- try:
196
- audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
197
- return self.buffer_to_text(audio_input)
198
- except Exception as e:
199
- print(f"Error loading audio file {filename}: {e}")
200
- return ""
201
 
 
 
 
202
 
203
- def convert_to_onnx(model_id_or_path, onnx_model_name):
204
- """Convert PyTorch model to ONNX format"""
205
- print(f"Converting {model_id_or_path} to ONNX...")
206
- model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
207
- model.eval()
208
-
209
- # Create dummy input
210
- audio_len = 250000
211
- dummy_input = torch.randn(1, audio_len, requires_grad=True)
212
-
213
- torch.onnx.export(
214
- model,
215
- dummy_input,
216
- onnx_model_name,
217
- export_params=True,
218
- opset_version=14,
219
- do_constant_folding=True,
220
- input_names=["input"],
221
- output_names=["output"],
222
- dynamic_axes={
223
- "input": {1: "audio_len"},
224
- "output": {1: "audio_len"},
225
- },
226
- )
227
- print(f"ONNX model saved to: {onnx_model_name}")
228
-
229
-
230
- def quantize_onnx_model(onnx_model_path, quantized_model_path):
231
- """Quantize ONNX model for faster inference"""
232
- print("Starting quantization...")
233
- from onnxruntime.quantization import quantize_dynamic, QuantType
234
-
235
- quantize_dynamic(
236
- onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
237
- )
238
- print(f"Quantized model saved to: {quantized_model_path}")
239
-
240
-
241
- def export_to_onnx(model_name, quantize=False):
242
- """
243
- Export model to ONNX format with optional quantization
244
-
245
- Args:
246
- model_name: HuggingFace model name
247
- quantize: Whether to also create quantized version
248
-
249
- Returns:
250
- tuple: (onnx_path, quantized_path or None)
251
- """
252
- onnx_filename = f"{model_name.split('/')[-1]}.onnx"
253
- convert_to_onnx(model_name, onnx_filename)
254
-
255
- quantized_path = None
256
- if quantize:
257
- quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
258
- quantize_onnx_model(onnx_filename, quantized_path)
259
-
260
- return onnx_filename, quantized_path
261
-
262
-
263
- def create_inference(
264
- model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
265
- ):
266
- """
267
- Create optimized inference instance
268
-
269
- Args:
270
- model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name (default: uses DEFAULT_MODEL)
271
- use_onnx: Whether to use ONNX runtime
272
- onnx_path: Path to ONNX model file
273
- use_gpu: Whether to use GPU if available
274
- use_onnx_quantize: Whether to use quantized ONNX model
275
-
276
- Returns:
277
- Inference instance
278
- """
279
- # Get the actual model name
280
- actual_model_name = get_model_name(model_name)
281
-
282
- if use_onnx:
283
- if not onnx_path or not os.path.exists(onnx_path):
284
- # Convert to ONNX if path not provided or doesn't exist
285
- onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
286
- convert_to_onnx(actual_model_name, onnx_filename)
287
- onnx_path = onnx_filename
288
-
289
- if use_onnx_quantize:
290
- quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
291
- if not os.path.exists(quantized_path):
292
- quantize_onnx_model(onnx_path, quantized_path)
293
- onnx_path = quantized_path
294
-
295
- print(f"Using ONNX model: {onnx_path}")
296
- return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
297
- else:
298
- print("Using PyTorch model")
299
- return Wave2Vec2Inference(model_name, use_gpu)
300
-
301
-
302
- if __name__ == "__main__":
303
- import time
304
-
305
- # Display available models
306
- print("Available Wave2Vec2 models:")
307
- for key, model_name in get_available_models().items():
308
- print(f" {key}: {model_name}")
309
- print(f"\nDefault model: {DEFAULT_MODEL}")
310
- print()
311
-
312
- # Test with different models
313
- test_models = ["english_large", "multilingual", "english_960h"]
314
- test_file = "test.wav"
315
-
316
- if not os.path.exists(test_file):
317
- print(f"Test file {test_file} not found. Please provide a valid audio file.")
318
- print("Creating example usage without actual file...")
319
-
320
- # Example usage without file
321
- print("\n=== Example Usage ===")
322
-
323
- # Using default model
324
- print("1. Using default model:")
325
- asr_default = create_inference()
326
- print(f" Model loaded: {asr_default.model_name}")
327
-
328
- # Using model key
329
- print("\n2. Using model key 'english_large':")
330
- asr_key = create_inference("english_large")
331
- print(f" Model loaded: {asr_key.model_name}")
332
-
333
- # Using full model name
334
- print("\n3. Using full model name:")
335
- asr_full = create_inference("facebook/wav2vec2-base-960h")
336
- print(f" Model loaded: {asr_full.model_name}")
337
-
338
- exit(0)
339
 
340
- # Test different model configurations
341
- for model_key in test_models:
342
- print(f"\n=== Testing model: {model_key} ===")
343
-
344
- # Test different configurations
345
- configs = [
346
- {"use_onnx": False, "use_gpu": True},
347
- {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
348
- ]
349
-
350
- for config in configs:
351
- print(f"\nConfig: {config}")
352
-
353
- # Create inference instance with model selection
354
- asr = create_inference(model_key, **config)
355
-
356
- # Warm up
357
- asr.file_to_text(test_file)
358
-
359
- # Test performance
360
- times = []
361
- for i in range(3):
362
- start_time = time.time()
363
- text = asr.file_to_text(test_file)
364
- end_time = time.time()
365
- execution_time = end_time - start_time
366
- times.append(execution_time)
367
- print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
368
-
369
- avg_time = sum(times) / len(times)
370
- print(f"Average time: {avg_time:.3f}s")
 
1
+ # import torch
2
+ # from transformers import (
3
+ # AutoModelForCTC,
4
+ # AutoProcessor,
5
+ # Wav2Vec2Processor,
6
+ # Wav2Vec2ForCTC,
7
+ # )
8
+ # import onnxruntime as rt
9
+ # import numpy as np
10
+ # import librosa
11
+ # import warnings
12
+ # import os
13
+
14
+ # warnings.filterwarnings("ignore")
15
+
16
+ # # Available Wave2Vec2 models
17
+ # WAVE2VEC2_MODELS = {
18
+ # "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
19
+ # "multilingual": "facebook/wav2vec2-large-xlsr-53",
20
+ # "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
21
+ # "base_english": "facebook/wav2vec2-base-960h",
22
+ # "large_english": "facebook/wav2vec2-large-960h",
23
+ # "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
24
+ # "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53"
25
+ # }
26
+
27
+ # # Default model
28
+ # DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
29
+
30
+
31
+ # def get_available_models():
32
+ # """Return dictionary of available Wave2Vec2 models"""
33
+ # return WAVE2VEC2_MODELS.copy()
34
+
35
+
36
+ # def get_model_name(model_key=None):
37
+ # """
38
+ # Get model name from key or return default
39
+
40
+ # Args:
41
+ # model_key: Key from WAVE2VEC2_MODELS or full model name
42
+
43
+ # Returns:
44
+ # str: Full model name
45
+ # """
46
+ # if model_key is None:
47
+ # return DEFAULT_MODEL
48
+
49
+ # if model_key in WAVE2VEC2_MODELS:
50
+ # return WAVE2VEC2_MODELS[model_key]
51
+
52
+ # # If it's already a full model name, return as is
53
+ # return model_key
54
 
 
55
 
56
+ # class Wave2Vec2Inference:
57
+ # def __init__(self, model_name=None, use_gpu=True):
58
+ # # Get the actual model name using helper function
59
+ # self.model_name = get_model_name(model_name)
60
+
61
+ # # Auto-detect device
62
+ # if use_gpu:
63
+ # if torch.backends.mps.is_available():
64
+ # self.device = "mps"
65
+ # elif torch.cuda.is_available():
66
+ # self.device = "cuda"
67
+ # else:
68
+ # self.device = "cpu"
69
+ # else:
70
+ # self.device = "cpu"
71
+
72
+ # print(f"Using device: {self.device}")
73
+ # print(f"Loading model: {self.model_name}")
74
+
75
+ # # Check if model is XLSR and use appropriate processor/model
76
+ # is_xlsr = "xlsr" in self.model_name.lower()
77
+
78
+ # if is_xlsr:
79
+ # print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
80
+ # self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
81
+ # self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
82
+ # else:
83
+ # print("Using AutoProcessor and AutoModelForCTC")
84
+ # self.processor = AutoProcessor.from_pretrained(self.model_name)
85
+ # self.model = AutoModelForCTC.from_pretrained(self.model_name)
86
+
87
+ # self.model.to(self.device)
88
+ # self.model.eval()
89
+
90
+ # # Disable gradients for inference
91
+ # torch.set_grad_enabled(False)
92
+
93
+ # def buffer_to_text(self, audio_buffer):
94
+ # if len(audio_buffer) == 0:
95
+ # return ""
96
+
97
+ # # Convert to tensor
98
+ # if isinstance(audio_buffer, np.ndarray):
99
+ # audio_tensor = torch.from_numpy(audio_buffer).float()
100
+ # else:
101
+ # audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
102
+
103
+ # # Process audio
104
+ # inputs = self.processor(
105
+ # audio_tensor,
106
+ # sampling_rate=16_000,
107
+ # return_tensors="pt",
108
+ # padding=True,
109
+ # )
110
+
111
+ # # Move to device
112
+ # input_values = inputs.input_values.to(self.device)
113
+ # attention_mask = (
114
+ # inputs.attention_mask.to(self.device)
115
+ # if "attention_mask" in inputs
116
+ # else None
117
+ # )
118
+
119
+ # # Inference
120
+ # with torch.no_grad():
121
+ # if attention_mask is not None:
122
+ # logits = self.model(input_values, attention_mask=attention_mask).logits
123
+ # else:
124
+ # logits = self.model(input_values).logits
125
+
126
+ # # Decode
127
+ # predicted_ids = torch.argmax(logits, dim=-1)
128
+ # if self.device != "cpu":
129
+ # predicted_ids = predicted_ids.cpu()
130
+
131
+ # transcription = self.processor.batch_decode(predicted_ids)[0]
132
+ # return transcription.lower().strip()
133
+
134
+ # def file_to_text(self, filename):
135
+ # try:
136
+ # audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
137
+ # return self.buffer_to_text(audio_input)
138
+ # except Exception as e:
139
+ # print(f"Error loading audio file {filename}: {e}")
140
+ # return ""
141
+
142
+
143
+ # class Wave2Vec2ONNXInference:
144
+ # def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
145
+ # # Get the actual model name using helper function
146
+ # self.model_name = get_model_name(model_name)
147
+ # print(f"Loading ONNX model: {self.model_name}")
148
+
149
+ # # Always use Wav2Vec2Processor for ONNX (works for all models)
150
+ # self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
151
+
152
+ # # Setup ONNX Runtime
153
+ # options = rt.SessionOptions()
154
+ # options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
155
+
156
+ # # Choose providers based on GPU availability
157
+ # providers = []
158
+ # if use_gpu and rt.get_available_providers():
159
+ # if "CUDAExecutionProvider" in rt.get_available_providers():
160
+ # providers.append("CUDAExecutionProvider")
161
+ # providers.append("CPUExecutionProvider")
162
+
163
+ # self.model = rt.InferenceSession(onnx_path, options, providers=providers)
164
+ # self.input_name = self.model.get_inputs()[0].name
165
+ # print(f"ONNX model loaded with providers: {self.model.get_providers()}")
166
+
167
+ # def buffer_to_text(self, audio_buffer):
168
+ # if len(audio_buffer) == 0:
169
+ # return ""
170
+
171
+ # # Convert to tensor
172
+ # if isinstance(audio_buffer, np.ndarray):
173
+ # audio_tensor = torch.from_numpy(audio_buffer).float()
174
+ # else:
175
+ # audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
176
+
177
+ # # Process audio
178
+ # inputs = self.processor(
179
+ # audio_tensor,
180
+ # sampling_rate=16_000,
181
+ # return_tensors="np",
182
+ # padding=True,
183
+ # )
184
+
185
+ # # ONNX inference
186
+ # input_values = inputs.input_values.astype(np.float32)
187
+ # onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
188
+
189
+ # # Decode
190
+ # prediction = np.argmax(onnx_outputs, axis=-1)
191
+ # transcription = self.processor.decode(prediction.squeeze().tolist())
192
+ # return transcription.lower().strip()
193
+
194
+ # def file_to_text(self, filename):
195
+ # try:
196
+ # audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
197
+ # return self.buffer_to_text(audio_input)
198
+ # except Exception as e:
199
+ # print(f"Error loading audio file {filename}: {e}")
200
+ # return ""
201
+
202
+
203
+ # def convert_to_onnx(model_id_or_path, onnx_model_name):
204
+ # """Convert PyTorch model to ONNX format"""
205
+ # print(f"Converting {model_id_or_path} to ONNX...")
206
+ # model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
207
+ # model.eval()
208
+
209
+ # # Create dummy input
210
+ # audio_len = 250000
211
+ # dummy_input = torch.randn(1, audio_len, requires_grad=True)
212
+
213
+ # torch.onnx.export(
214
+ # model,
215
+ # dummy_input,
216
+ # onnx_model_name,
217
+ # export_params=True,
218
+ # opset_version=14,
219
+ # do_constant_folding=True,
220
+ # input_names=["input"],
221
+ # output_names=["output"],
222
+ # dynamic_axes={
223
+ # "input": {1: "audio_len"},
224
+ # "output": {1: "audio_len"},
225
+ # },
226
+ # )
227
+ # print(f"ONNX model saved to: {onnx_model_name}")
228
+
229
+
230
+ # def quantize_onnx_model(onnx_model_path, quantized_model_path):
231
+ # """Quantize ONNX model for faster inference"""
232
+ # print("Starting quantization...")
233
+ # from onnxruntime.quantization import quantize_dynamic, QuantType
234
+
235
+ # quantize_dynamic(
236
+ # onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
237
+ # )
238
+ # print(f"Quantized model saved to: {quantized_model_path}")
239
+
240
+
241
+ # def export_to_onnx(model_name, quantize=False):
242
+ # """
243
+ # Export model to ONNX format with optional quantization
244
+
245
+ # Args:
246
+ # model_name: HuggingFace model name
247
+ # quantize: Whether to also create quantized version
248
+
249
+ # Returns:
250
+ # tuple: (onnx_path, quantized_path or None)
251
+ # """
252
+ # onnx_filename = f"{model_name.split('/')[-1]}.onnx"
253
+ # convert_to_onnx(model_name, onnx_filename)
254
+
255
+ # quantized_path = None
256
+ # if quantize:
257
+ # quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
258
+ # quantize_onnx_model(onnx_filename, quantized_path)
259
+
260
+ # return onnx_filename, quantized_path
261
+
262
+
263
+ # def create_inference(
264
+ # model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
265
+ # ):
266
+ # """
267
+ # Create optimized inference instance
268
+
269
+ # Args:
270
+ # model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name (default: uses DEFAULT_MODEL)
271
+ # use_onnx: Whether to use ONNX runtime
272
+ # onnx_path: Path to ONNX model file
273
+ # use_gpu: Whether to use GPU if available
274
+ # use_onnx_quantize: Whether to use quantized ONNX model
275
+
276
+ # Returns:
277
+ # Inference instance
278
+ # """
279
+ # # Get the actual model name
280
+ # actual_model_name = get_model_name(model_name)
281
+
282
+ # if use_onnx:
283
+ # if not onnx_path or not os.path.exists(onnx_path):
284
+ # # Convert to ONNX if path not provided or doesn't exist
285
+ # onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
286
+ # convert_to_onnx(actual_model_name, onnx_filename)
287
+ # onnx_path = onnx_filename
288
+
289
+ # if use_onnx_quantize:
290
+ # quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
291
+ # if not os.path.exists(quantized_path):
292
+ # quantize_onnx_model(onnx_path, quantized_path)
293
+ # onnx_path = quantized_path
294
+
295
+ # print(f"Using ONNX model: {onnx_path}")
296
+ # return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
297
+ # else:
298
+ # print("Using PyTorch model")
299
+ # return Wave2Vec2Inference(model_name, use_gpu)
300
+
301
+
302
+ # if __name__ == "__main__":
303
+ # import time
304
+
305
+ # # Display available models
306
+ # print("Available Wave2Vec2 models:")
307
+ # for key, model_name in get_available_models().items():
308
+ # print(f" {key}: {model_name}")
309
+ # print(f"\nDefault model: {DEFAULT_MODEL}")
310
+ # print()
311
+
312
+ # # Test with different models
313
+ # test_models = ["english_large", "multilingual", "english_960h"]
314
+ # test_file = "test.wav"
315
+
316
+ # if not os.path.exists(test_file):
317
+ # print(f"Test file {test_file} not found. Please provide a valid audio file.")
318
+ # print("Creating example usage without actual file...")
319
+
320
+ # # Example usage without file
321
+ # print("\n=== Example Usage ===")
322
+
323
+ # # Using default model
324
+ # print("1. Using default model:")
325
+ # asr_default = create_inference()
326
+ # print(f" Model loaded: {asr_default.model_name}")
327
+
328
+ # # Using model key
329
+ # print("\n2. Using model key 'english_large':")
330
+ # asr_key = create_inference("english_large")
331
+ # print(f" Model loaded: {asr_key.model_name}")
332
+
333
+ # # Using full model name
334
+ # print("\n3. Using full model name:")
335
+ # asr_full = create_inference("facebook/wav2vec2-base-960h")
336
+ # print(f" Model loaded: {asr_full.model_name}")
337
+
338
+ # exit(0)
339
 
340
+ # # Test different model configurations
341
+ # for model_key in test_models:
342
+ # print(f"\n=== Testing model: {model_key} ===")
343
+
344
+ # # Test different configurations
345
+ # configs = [
346
+ # {"use_onnx": False, "use_gpu": True},
347
+ # {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
348
+ # ]
349
 
350
+ # for config in configs:
351
+ # print(f"\nConfig: {config}")
352
 
353
+ # # Create inference instance with model selection
354
+ # asr = create_inference(model_key, **config)
 
355
 
356
+ # # Warm up
357
+ # asr.file_to_text(test_file)
358
 
359
+ # # Test performance
360
+ # times = []
361
+ # for i in range(3):
362
+ # start_time = time.time()
363
+ # text = asr.file_to_text(test_file)
364
+ # end_time = time.time()
365
+ # execution_time = end_time - start_time
366
+ # times.append(execution_time)
367
+ # print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
368
+
369
+ # avg_time = sum(times) / len(times)
370
+ # print(f"Average time: {avg_time:.3f}s")
371
+
372
+
373
+
374
+ import torch
375
+ from transformers import (
376
+ Wav2Vec2ForCTC,
377
+ Wav2Vec2Processor,
378
+ AutoProcessor,
379
+ AutoModelForCTC,
380
+ )
381
+
382
+ import deepspeed
383
+ import librosa
384
+ import numpy as np
385
+ from typing import Optional, List, Union
386
+
387
+
388
+ def get_model_name(model_name: Optional[str] = None) -> str:
389
+ """Helper function to get model name with default fallback"""
390
+ if model_name is None:
391
+ return "facebook/wav2vec2-large-robust-ft-libri-960h"
392
+ return model_name
393
 
394
 
395
  class Wave2Vec2Inference:
396
+ def __init__(
397
+ self,
398
+ model_name: Optional[str] = None,
399
+ use_gpu: bool = True,
400
+ use_deepspeed: bool = True,
401
+ ):
402
+ """
403
+ Initialize Wav2Vec2 model for inference with optional DeepSpeed optimization.
404
+
405
+ Args:
406
+ model_name: HuggingFace model name or None for default
407
+ use_gpu: Whether to use GPU acceleration
408
+ use_deepspeed: Whether to use DeepSpeed optimization
409
+ """
410
  # Get the actual model name using helper function
411
  self.model_name = get_model_name(model_name)
412
+ self.use_deepspeed = use_deepspeed
413
+
414
  # Auto-detect device
415
  if use_gpu:
416
  if torch.backends.mps.is_available():
 
424
 
425
  print(f"Using device: {self.device}")
426
  print(f"Loading model: {self.model_name}")
427
+ print(f"DeepSpeed enabled: {self.use_deepspeed}")
428
 
429
  # Check if model is XLSR and use appropriate processor/model
430
  is_xlsr = "xlsr" in self.model_name.lower()
431
+
432
  if is_xlsr:
433
  print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
434
  self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
 
437
  print("Using AutoProcessor and AutoModelForCTC")
438
  self.processor = AutoProcessor.from_pretrained(self.model_name)
439
  self.model = AutoModelForCTC.from_pretrained(self.model_name)
440
+
441
+ # Initialize DeepSpeed if enabled
442
+ if self.use_deepspeed:
443
+ self._init_deepspeed()
444
+ else:
445
+ self.model.to(self.device)
446
+ self.model.eval()
447
+ self.ds_engine = None
448
 
449
  # Disable gradients for inference
450
  torch.set_grad_enabled(False)
451
 
452
+ def _init_deepspeed(self):
453
+ """Initialize DeepSpeed inference engine"""
454
+ try:
455
+ # DeepSpeed configuration based on device
456
+ if self.device == "cuda":
457
+ ds_config = {
458
+ "tensor_parallel": {"tp_size": 1},
459
+ "dtype": torch.float32,
460
+ "replace_with_kernel_inject": True,
461
+ "enable_cuda_graph": False,
462
+ }
463
+ else:
464
+ ds_config = {
465
+ "tensor_parallel": {"tp_size": 1},
466
+ "dtype": torch.float32,
467
+ "replace_with_kernel_inject": False,
468
+ "enable_cuda_graph": False,
469
+ }
470
+
471
+ print("Initializing DeepSpeed inference engine...")
472
+ self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
473
+ self.ds_engine.module.to(self.device)
474
+
475
+ except Exception as e:
476
+ print(f"DeepSpeed initialization failed: {e}")
477
+ print("Falling back to standard PyTorch inference...")
478
+ self.use_deepspeed = False
479
+ self.ds_engine = None
480
+ self.model.to(self.device)
481
+ self.model.eval()
482
+
483
+ def _get_model(self):
484
+ """Get the appropriate model for inference"""
485
+ if self.use_deepspeed and self.ds_engine is not None:
486
+ return self.ds_engine.module
487
+ return self.model
488
+
489
+ def buffer_to_text(
490
+ self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
491
+ ) -> str:
492
+ """
493
+ Convert audio buffer to text transcription.
494
+
495
+ Args:
496
+ audio_buffer: Audio data as numpy array, tensor, or list
497
+
498
+ Returns:
499
+ str: Transcribed text
500
+ """
501
  if len(audio_buffer) == 0:
502
  return ""
503
 
504
  # Convert to tensor
505
  if isinstance(audio_buffer, np.ndarray):
506
  audio_tensor = torch.from_numpy(audio_buffer).float()
507
+ elif isinstance(audio_buffer, list):
508
  audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
509
+ else:
510
+ audio_tensor = audio_buffer.float()
511
 
512
  # Process audio
513
  inputs = self.processor(
 
525
  else None
526
  )
527
 
528
+ # Get the appropriate model
529
+ model = self._get_model()
530
+
531
  # Inference
532
  with torch.no_grad():
533
  if attention_mask is not None:
534
+ outputs = model(input_values, attention_mask=attention_mask)
535
  else:
536
+ outputs = model(input_values)
537
+
538
+ # Handle different output formats
539
+ if hasattr(outputs, "logits"):
540
+ logits = outputs.logits
541
+ else:
542
+ logits = outputs
543
 
544
  # Decode
545
  predicted_ids = torch.argmax(logits, dim=-1)
 
549
  transcription = self.processor.batch_decode(predicted_ids)[0]
550
  return transcription.lower().strip()
551
 
552
+ def file_to_text(self, filename: str) -> str:
553
+ """
554
+ Transcribe audio file to text.
555
+
556
+ Args:
557
+ filename: Path to audio file
558
+
559
+ Returns:
560
+ str: Transcribed text
561
+ """
562
  try:
563
  audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
564
  return self.buffer_to_text(audio_input)
 
566
  print(f"Error loading audio file {filename}: {e}")
567
  return ""
568
 
569
+ def batch_file_to_text(self, filenames: List[str]) -> List[str]:
570
+ """
571
+ Transcribe multiple audio files to text.
572
+
573
+ Args:
574
+ filenames: List of audio file paths
575
+
576
+ Returns:
577
+ List[str]: List of transcribed texts
578
+ """
579
+ results = []
580
+ for i, filename in enumerate(filenames):
581
+ print(f"Processing file {i+1}/{len(filenames)}: {filename}")
582
+ transcription = self.file_to_text(filename)
583
+ results.append(transcription)
584
+ if transcription:
585
+ print(f"Transcription: {transcription}")
586
+ else:
587
+ print("Failed to transcribe")
588
+ return results
589
 
590
+ def transcribe_with_confidence(
591
+ self, audio_buffer: Union[np.ndarray, torch.Tensor]
592
+ ) -> tuple:
593
+ """
594
+ Transcribe audio and return confidence scores.
595
 
596
+ Args:
597
+ audio_buffer: Audio data
 
598
 
599
+ Returns:
600
+ tuple: (transcription, confidence_scores)
601
+ """
602
  if len(audio_buffer) == 0:
603
+ return "", []
604
 
605
  # Convert to tensor
606
  if isinstance(audio_buffer, np.ndarray):
607
  audio_tensor = torch.from_numpy(audio_buffer).float()
608
  else:
609
+ audio_tensor = audio_buffer.float()
610
 
611
  # Process audio
612
  inputs = self.processor(
613
  audio_tensor,
614
  sampling_rate=16_000,
615
+ return_tensors="pt",
616
  padding=True,
617
  )
618
 
619
+ input_values = inputs.input_values.to(self.device)
620
+ attention_mask = (
621
+ inputs.attention_mask.to(self.device)
622
+ if "attention_mask" in inputs
623
+ else None
624
+ )
625
+
626
+ model = self._get_model()
627
 
628
+ # Inference
629
+ with torch.no_grad():
630
+ if attention_mask is not None:
631
+ outputs = model(input_values, attention_mask=attention_mask)
632
+ else:
633
+ outputs = model(input_values)
634
 
635
+ if hasattr(outputs, "logits"):
636
+ logits = outputs.logits
637
+ else:
638
+ logits = outputs
 
 
 
639
 
640
+ # Get probabilities and confidence scores
641
+ probs = torch.nn.functional.softmax(logits, dim=-1)
642
+ predicted_ids = torch.argmax(logits, dim=-1)
643
 
644
+ # Calculate confidence as max probability for each prediction
645
+ max_probs = torch.max(probs, dim=-1)[0]
646
+ confidence_scores = max_probs.cpu().numpy().tolist()
647
 
648
+ if self.device != "cpu":
649
+ predicted_ids = predicted_ids.cpu()
650
+
651
+ transcription = self.processor.batch_decode(predicted_ids)[0]
652
+ return transcription.lower().strip(), confidence_scores
653
+
654
+ def cleanup(self):
655
+ """Clean up resources"""
656
+ if hasattr(self, "ds_engine") and self.ds_engine is not None:
657
+ del self.ds_engine
658
+ if hasattr(self, "model"):
659
+ del self.model
660
+ if hasattr(self, "processor"):
661
+ del self.processor
662
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
663
+
664
+ def __del__(self):
665
+ """Destructor to clean up resources"""
666
+ self.cleanup()
 
src/apis/controllers/speaking_controller.py CHANGED
@@ -14,10 +14,12 @@ import Levenshtein
 from dataclasses import dataclass
 from enum import Enum
 import os
- from src.AI_Models.wave2vec_inference import (
-     create_inference,
-     export_to_onnx,
- )
+
+ # from src.AI_Models.wave2vec_inference import (
+ #     create_inference,
+ #     export_to_onnx,
+ # )
+ from src.AI_Models.wave2vec_inference import Wave2Vec2Inference
 from src.utils.vietnamese_tips import vietnamese_tips

 # Download required NLTK data
@@ -78,9 +80,7 @@ class EnhancedWav2Vec2CharacterASR:
         export_to_onnx(model_name, quantize=quantized)

         # Use optimized inference
-         self.model = create_inference(
-             model_name=model_name, use_onnx=onnx, use_onnx_quantize=quantized
-         )
+         self.model = Wave2Vec2Inference(model_name, use_gpu=False, use_deepspeed=True)

     def transcribe_with_features(self, audio_path: str, retry_count: int = 0) -> Dict:
         """Enhanced transcription with audio features for prosody analysis - Optimized with retry mechanism"""
src/apis/routes/speaking_route.py CHANGED
@@ -511,7 +511,10 @@ async def assess_pronunciation(
         await optimize_post_assessment_processing(result, reference_text)

         # Add processing time
+
         processing_time = time.time() - start_time
+         if "processing_info" not in result:
+             result["processing_info"] = {}
         result["processing_info"]["processing_time"] = processing_time

         # Convert numpy types for JSON serialization
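
The guard added here can also be collapsed with dict.setdefault; a small equivalent sketch using the same names as the diff:

processing_time = time.time() - start_time
result.setdefault("processing_info", {})["processing_time"] = processing_time  # same effect as the added lines above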