import torch
from transformers import (
AutoModelForCTC,
AutoProcessor,
Wav2Vec2Processor,
Wav2Vec2ForCTC,
)
import onnxruntime as rt
import numpy as np
import librosa
import warnings
import os
warnings.filterwarnings("ignore")
class Wave2Vec2Inference:
    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True, enable_optimizations=True):
# Auto-detect best available device
if use_gpu:
if torch.backends.mps.is_available():
self.device = "mps"
elif torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
else:
self.device = "cpu"
print(f"Using device: {self.device}")
# Set optimal torch settings for inference
torch.set_grad_enabled(False) # Disable gradients globally for inference
if self.device == "cpu":
# CPU optimizations
            torch.set_num_threads(os.cpu_count())  # Use all available CPU cores
torch.set_float32_matmul_precision('high')
elif self.device == "cuda":
# CUDA optimizations
torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark mode
torch.backends.cudnn.deterministic = False
elif self.device == "mps":
            # MPS optimizations: CPU fallback for unsupported ops is controlled by the
            # PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable, which must be set
            # before torch is imported; there is no runtime attribute to toggle it here.
            pass
if use_lm_if_possible:
self.processor = AutoProcessor.from_pretrained(model_name)
else:
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
self.model = AutoModelForCTC.from_pretrained(model_name)
self.model.to(self.device)
# Set model to evaluation mode for inference optimization
self.model.eval()
# Try to optimize model for inference (safe version) - only if enabled
if enable_optimizations:
try:
# First try torch.compile (PyTorch 2.0+) - more robust
if hasattr(torch, 'compile') and self.device != "mps": # MPS doesn't support torch.compile yet
self.model = torch.compile(self.model, mode="reduce-overhead")
print("Model compiled with torch.compile for faster inference")
else:
# Alternative: try JIT scripting for older PyTorch versions
try:
scripted_model = torch.jit.script(self.model)
if hasattr(torch.jit, 'optimize_for_inference'):
scripted_model = torch.jit.optimize_for_inference(scripted_model)
self.model = scripted_model
print("Model optimized with JIT scripting")
except Exception as jit_e:
print(f"JIT optimization failed, using regular model: {jit_e}")
except Exception as e:
print(f"Model optimization failed, using regular model: {e}")
else:
print("Model optimizations disabled")
        self.hotwords = hotwords if hotwords is not None else []  # avoid a mutable default argument
self.use_lm_if_possible = use_lm_if_possible
# Pre-allocate tensors for common audio lengths to avoid repeated allocation
self.tensor_cache = {}
# Warm up the model with a dummy input (only if optimizations enabled)
if enable_optimizations:
self._warmup_model()
def _warmup_model(self):
"""Warm up the model with dummy input to optimize first inference"""
try:
            # Keep the dummy audio on CPU; the processor converts it to numpy internally
            dummy_audio = torch.zeros(16000)  # 1 second of silence at 16 kHz
dummy_inputs = self.processor(
dummy_audio,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
# Move inputs to device
dummy_inputs = {k: v.to(self.device) for k, v in dummy_inputs.items()}
# Run dummy inference
with torch.no_grad():
_ = self.model(
dummy_inputs["input_values"],
attention_mask=dummy_inputs.get("attention_mask")
)
print("Model warmed up successfully")
except Exception as e:
print(f"Warmup failed: {e}")
def buffer_to_text(self, audio_buffer):
if len(audio_buffer) == 0:
return ""
# Convert to tensor with optimal dtype and device placement
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
else:
audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
# Use optimized processing
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
# Move to device in one operation
input_values = inputs.input_values.to(self.device, non_blocking=True)
attention_mask = inputs.attention_mask.to(self.device, non_blocking=True) if "attention_mask" in inputs else None
        # Mixed-precision (autocast) inference on CUDA; full precision on CPU and MPS
        if self.device == "cuda":
            with torch.no_grad(), torch.autocast(device_type="cuda"):
                if attention_mask is not None:
                    logits = self.model(input_values, attention_mask=attention_mask).logits
                else:
                    logits = self.model(input_values).logits
        else:
            # CPU / MPS inference without autocast
            with torch.no_grad():
                if attention_mask is not None:
                    logits = self.model(input_values, attention_mask=attention_mask).logits
                else:
                    logits = self.model(input_values).logits
# Optimized decoding
if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
# Move to CPU for decoder processing (decoder only works on CPU)
logits_cpu = logits[0].cpu().numpy()
transcription = self.processor.decode(
logits_cpu,
hotwords=self.hotwords,
output_word_offsets=True,
)
            # Average LM score per word; note this value is computed but not returned
            confidence = transcription.lm_score / max(len(transcription.text.split(" ")), 1)
            transcription: str = transcription.text
else:
# Fast argmax on GPU/MPS, then move to CPU for batch_decode
predicted_ids = torch.argmax(logits, dim=-1)
if self.device != "cpu":
predicted_ids = predicted_ids.cpu()
transcription: str = self.processor.batch_decode(predicted_ids)[0]
return transcription.lower().strip()
def confidence_score(self, logits, predicted_ids):
scores = torch.nn.functional.softmax(logits, dim=-1)
pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
mask = torch.logical_and(
predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
)
character_scores = pred_scores.masked_select(mask)
total_average = torch.sum(character_scores) / len(character_scores)
return total_average
def file_to_text(self, filename):
# Optimized audio loading
try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)  # resample to 16 kHz mono
return self.buffer_to_text(audio_input)
except Exception as e:
print(f"Error loading audio file {filename}: {e}")
return ""
class Wave2Vec2ONNXInference:
def __init__(self, model_name, onnx_path):
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
# Optimized ONNX Runtime session
options = rt.SessionOptions()
options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
        options.inter_op_num_threads = 0  # 0 lets ONNX Runtime choose (typically all physical cores)
        options.intra_op_num_threads = 0  # 0 lets ONNX Runtime choose (typically all physical cores)
        # Prefer the CUDA execution provider when available, with CPU fallback
providers = []
if rt.get_device() == 'GPU':
providers.append('CUDAExecutionProvider')
providers.extend(['CPUExecutionProvider'])
self.model = rt.InferenceSession(
onnx_path,
options,
providers=providers
)
# Pre-compile input name for faster access
self.input_name = self.model.get_inputs()[0].name
print(f"ONNX model loaded with providers: {self.model.get_providers()}")
def buffer_to_text(self, audio_buffer):
if len(audio_buffer) == 0:
return ""
# Optimized preprocessing
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
else:
audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="np",
padding=True,
)
# Optimized ONNX inference
input_values = inputs.input_values.astype(np.float32)
onnx_outputs = self.model.run(
None,
{self.input_name: input_values}
)[0]
# Fast argmax and decoding
prediction = np.argmax(onnx_outputs, axis=-1)
transcription = self.processor.decode(prediction.squeeze().tolist())
return transcription.lower().strip()
def file_to_text(self, filename):
try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)  # resample to 16 kHz mono
return self.buffer_to_text(audio_input)
except Exception as e:
print(f"Error loading audio file {filename}: {e}")
return ""
# The ONNX conversion helpers further below (convert_to_onnx, quantize_onnx_model) are adapted from:
# https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
class OptimizedWave2Vec2Factory:
"""Factory class to create the most optimized Wave2Vec2 inference instance"""
@staticmethod
def create_optimized_inference(model_name, onnx_path=None, safe_mode=False, **kwargs):
"""
Create the most optimized inference instance based on available resources
Args:
model_name: HuggingFace model name
onnx_path: Path to ONNX model (optional, for maximum speed)
safe_mode: If True, disable aggressive optimizations that might cause issues
**kwargs: Additional arguments for Wave2Vec2Inference
Returns:
Optimized inference instance
"""
if onnx_path and os.path.exists(onnx_path):
print("Using ONNX model for maximum speed")
return Wave2Vec2ONNXInference(model_name, onnx_path)
else:
print("Using PyTorch model with optimizations")
# In safe mode, disable optimizations that might cause issues
if safe_mode:
kwargs['enable_optimizations'] = False
print("Running in safe mode - optimizations disabled")
return Wave2Vec2Inference(model_name, **kwargs)
@staticmethod
def create_safe_inference(model_name, **kwargs):
"""Create a safe inference instance without aggressive optimizations"""
kwargs['enable_optimizations'] = False
return Wave2Vec2Inference(model_name, **kwargs)
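
# Illustrative usage sketch: the factory returns the ONNX backend when the given
# onnx_path exists and otherwise falls back to the (optionally safe-mode) PyTorch
# backend. The path below is an assumption for the example only.
#
#   asr = OptimizedWave2Vec2Factory.create_optimized_inference(
#       "facebook/wav2vec2-large-960h-lv60-self",
#       onnx_path="wav2vec2-large-960h-lv60-self.onnx",  # used only if the file exists
#       safe_mode=True,                                  # disables torch.compile / JIT
#   )
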
def convert_to_onnx(model_id_or_path, onnx_model_name):
print(f"Converting {model_id_or_path} to onnx")
model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
audio_len = 250000
x = torch.randn(1, audio_len, requires_grad=True)
torch.onnx.export(
model, # model being run
x, # model input (or a tuple for multiple inputs)
onnx_model_name, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=14, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=["input"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes={
"input": {1: "audio_len"}, # variable length axes
"output": {1: "audio_len"},
},
)
def quantize_onnx_model(onnx_model_path, quantized_model_path):
print("Starting quantization...")
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
)
print(f"Quantized model saved to: {quantized_model_path}")
def export_to_onnx(
model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
onnx_model_name = model.split("/")[-1] + ".onnx"
convert_to_onnx(model, onnx_model_name)
if quantize:
quantized_model_name = model.split("/")[-1] + ".quant.onnx"
quantize_onnx_model(onnx_model_name, quantized_model_name)
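
# Illustrative export-then-serve sketch, assuming the export succeeds with your
# installed torch/onnx versions (opset 14, dynamic audio length) and writes to the
# current working directory.
#
#   export_to_onnx("facebook/wav2vec2-large-960h-lv60-self", quantize=True)
#   # -> wav2vec2-large-960h-lv60-self.onnx and wav2vec2-large-960h-lv60-self.quant.onnx
#   asr = Wave2Vec2ONNXInference(
#       "facebook/wav2vec2-large-960h-lv60-self",
#       "wav2vec2-large-960h-lv60-self.quant.onnx",
#   )
#   print(asr.file_to_text("test.wav"))
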
if __name__ == "__main__":
from loguru import logger
import time
# Use optimized factory to create the best inference instance
asr = OptimizedWave2Vec2Factory.create_optimized_inference(
"facebook/wav2vec2-large-960h-lv60-self"
)
# Test if file exists
test_file = "test.wav"
if not os.path.exists(test_file):
print(f"Test file {test_file} not found. Please provide a valid audio file.")
exit(1)
# Warm up runs (model already warmed up during initialization)
print("Running additional warm-up...")
for i in range(2):
asr.file_to_text(test_file)
print(f"Warm up {i+1} completed")
# Test runs
print("Running optimized performance tests...")
times = []
for i in range(10):
start_time = time.time()
text = asr.file_to_text(test_file)
end_time = time.time()
execution_time = end_time - start_time
times.append(execution_time)
print(f"Test {i+1}: {execution_time:.3f}s - {text}")
# Calculate statistics
average_time = sum(times) / len(times)
min_time = min(times)
max_time = max(times)
std_time = np.std(times)
print(f"\n=== Performance Statistics ===")
print(f"Average execution time: {average_time:.3f}s")
print(f"Min time: {min_time:.3f}s")
print(f"Max time: {max_time:.3f}s")
print(f"Standard deviation: {std_time:.3f}s")
print(f"Speed improvement: ~{((max_time - min_time) / max_time * 100):.1f}% faster (min vs max)")
# Calculate throughput
if times:
throughput = 1.0 / average_time
print(f"Average throughput: {throughput:.2f} inferences/second")