# Run_code_api / src / AI_Models / wave2vec_inference.py
# ABAO77 — commit 85fa45c: "Refactor code structure for improved readability and maintainability"
# (raw · history · blame · 15.1 kB)
import contextlib
import os
import warnings

import librosa
import numpy as np
import onnxruntime as rt
import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
)
# Import-time, process-wide side effect: ignore every Python warning, from this
# module and from any other library loaded in the same process (no category= or
# module= filter is given). NOTE(review): consider narrowing the filter so
# unrelated warnings stay visible.
warnings.filterwarnings("ignore")
class Wave2Vec2Inference:
    """Transcribe speech with a Hugging Face CTC model (e.g. wav2vec2) in PyTorch.

    The constructor picks a device, loads the processor and model, optionally
    tries ``torch.compile``/TorchScript and warms the model up. Audio is
    expected at 16 kHz: ``buffer_to_text`` passes ``sampling_rate=16_000`` to the
    processor and ``file_to_text`` resamples to 16 kHz when loading.
    """

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True, enable_optimizations=True):
        """Load processor and model and prepare them for inference.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            hotwords: Words forwarded to ``processor.decode`` when LM decoding is
                used. ``None`` (the default) means no hotwords; each instance
                gets its own list instead of sharing a mutable default.
            use_lm_if_possible: Load the processor with ``AutoProcessor`` (which
                may expose an LM ``decoder``) and decode with it when present.
                When False, ``Wav2Vec2Processor`` is used and decoding is greedy.
            use_gpu: Prefer MPS, then CUDA, when available; otherwise use CPU.
            enable_optimizations: Try ``torch.compile`` (or TorchScript when
                compile is unavailable/skipped) and run a warm-up inference.
        """
        if use_gpu and torch.backends.mps.is_available():
            self.device = "mps"
        elif use_gpu and torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        print(f"Using device: {self.device}")

        # Gradients are disabled locally with torch.no_grad() around every forward
        # pass. A global torch.set_grad_enabled(False) is deliberately avoided:
        # grad mode is thread-local, so it would leak into unrelated code on the
        # constructing thread while not covering inference run on other threads.
        if self.device == "cpu":
            # Process-wide setting; presumably meant to allow faster float32 matmuls.
            torch.set_float32_matmul_precision('high')
        elif self.device == "cuda":
            # Let cuDNN benchmark and pick kernels (process-wide settings).
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
        elif self.device == "mps":
            # NOTE(review): this attribute does not appear to be read by PyTorch;
            # MPS->CPU fallback is normally enabled through the
            # PYTORCH_ENABLE_MPS_FALLBACK environment variable set before torch
            # is imported — confirm before relying on it.
            torch.backends.mps.enable_fallback = True

        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        # Keep the eager model so an optimization that only fails on first use
        # (torch.compile compiles lazily) can be undone during warm-up.
        self._eager_model = self.model

        if enable_optimizations:
            self._try_optimize_model()
        else:
            print("Model optimizations disabled")

        self.hotwords = [] if hotwords is None else hotwords
        self.use_lm_if_possible = use_lm_if_possible
        # Not used by this class; kept so existing attribute access keeps working.
        self.tensor_cache = {}

        if enable_optimizations:
            self._warmup_model()

    def _try_optimize_model(self):
        """Best-effort: replace ``self.model`` with a compiled or scripted version.

        Failures are printed and the regular model is kept.
        """
        try:
            # torch.compile is skipped on MPS (treated as unsupported there).
            if hasattr(torch, 'compile') and self.device != "mps":
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled with torch.compile for faster inference")
                return
            # Fallback for MPS or PyTorch builds without torch.compile.
            try:
                scripted_model = torch.jit.script(self.model)
                if hasattr(torch.jit, 'optimize_for_inference'):
                    scripted_model = torch.jit.optimize_for_inference(scripted_model)
                self.model = scripted_model
                print("Model optimized with JIT scripting")
            except Exception as jit_e:
                print(f"JIT optimization failed, using regular model: {jit_e}")
        except Exception as e:
            print(f"Model optimization failed, using regular model: {e}")

    def _compute_logits(self, audio_tensor):
        """Run processor + model on one CPU float32 audio tensor; return the logits."""
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(self.device, non_blocking=True)
        attention_mask = (
            inputs.attention_mask.to(self.device, non_blocking=True)
            if "attention_mask" in inputs
            else None
        )
        # Mixed precision only on CUDA; MPS and CPU run without autocast.
        autocast_ctx = (
            torch.autocast(device_type="cuda")
            if self.device == "cuda"
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            if attention_mask is not None:
                return self.model(input_values, attention_mask=attention_mask).logits
            return self.model(input_values).logits

    def _warmup_model(self):
        """Run one forward pass on silence so first-call work happens up front.

        Uses the same path as ``buffer_to_text``. If it fails and the model was
        replaced by an optimized version, revert to the eager model so later
        calls still work.
        """
        try:
            self._compute_logits(torch.zeros(16000))  # 1 second of silence at 16 kHz
            print("Model warmed up successfully")
        except Exception as e:
            print(f"Warmup failed: {e}")
            eager_model = getattr(self, "_eager_model", None)
            if eager_model is not None and self.model is not eager_model:
                print("Falling back to the non-optimized model")
                self.model = eager_model

    def buffer_to_text(self, audio_buffer):
        """Transcribe a mono audio buffer sampled at 16 kHz.

        Args:
            audio_buffer: 1-D ``np.ndarray`` or sequence of float samples.

        Returns:
            The lower-cased, whitespace-stripped transcription ("" for an empty buffer).
        """
        if len(audio_buffer) == 0:
            return ""
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        logits = self._compute_logits(audio_tensor)

        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
            # The LM decoder is fed a NumPy array on the CPU (first batch item).
            decoded = self.processor.decode(
                logits[0].cpu().numpy(),
                hotwords=self.hotwords,
            )
            transcription = decoded.text
        else:
            # Greedy CTC: argmax on the model's device, decode on the CPU.
            predicted_ids = torch.argmax(logits, dim=-1).cpu()
            transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip()

    def confidence_score(self, logits, predicted_ids):
        """Average softmax probability of the predicted character tokens.

        Word-delimiter and pad tokens are excluded from the average.

        Args:
            logits: Model logits; the indexing below assumes (batch, time, vocab).
            predicted_ids: Predicted token ids with shape (batch, time).

        Returns:
            A 0-d tensor; 0.0 when no character token was predicted (instead of
            the NaN a 0/0 division would produce).
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.not_equal(tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            return pred_scores.new_zeros(())
        return torch.sum(character_scores) / len(character_scores)

    def file_to_text(self, filename):
        """Load an audio file resampled to 16 kHz and transcribe it.

        Returns "" (after printing the error) if the file cannot be loaded;
        errors raised during inference propagate to the caller.
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""
        return self.buffer_to_text(audio_input)
class Wave2Vec2ONNXInference:
    """Transcribe speech with a wav2vec2 CTC model exported to ONNX.

    The forward pass runs in ONNX Runtime; decoding is greedy (argmax) through
    the Hugging Face processor. No language model or hotwords are applied.
    """

    def __init__(self, model_name, onnx_path):
        """Create the processor and an ONNX Runtime inference session.

        Args:
            model_name: Model id or path for ``Wav2Vec2Processor.from_pretrained``;
                presumably the same model that was exported to ``onnx_path``.
            onnx_path: Path of the ``.onnx`` file to load.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
        # 0 leaves the thread counts to ONNX Runtime's default choice.
        options.inter_op_num_threads = 0
        options.intra_op_num_threads = 0

        # Providers are listed in priority order; CPU always remains as fallback.
        providers = ['CPUExecutionProvider']
        if rt.get_device() == 'GPU':
            providers.insert(0, 'CUDAExecutionProvider')

        self.model = rt.InferenceSession(
            onnx_path,
            options,
            providers=providers
        )
        # Assumes the first graph input is the audio (true for convert_to_onnx exports).
        self.input_name = self.model.get_inputs()[0].name
        print(f"ONNX model loaded with providers: {self.model.get_providers()}")

    def buffer_to_text(self, audio_buffer):
        """Transcribe a mono audio buffer sampled at 16 kHz.

        Returns the lower-cased, whitespace-stripped transcription ("" for an
        empty buffer).
        """
        if len(audio_buffer) == 0:
            return ""
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="np",
            padding=True,
        )
        input_values = inputs.input_values.astype(np.float32)
        logits = self.model.run(None, {self.input_name: input_values})[0]

        # Take the single batch item explicitly: squeeze() would also drop a
        # length-1 time axis and hand decode() a bare int instead of a list of ids.
        predicted_ids = np.argmax(logits, axis=-1)[0]
        transcription = self.processor.decode(predicted_ids.tolist())
        return transcription.lower().strip()

    def file_to_text(self, filename):
        """Load an audio file resampled to 16 kHz and transcribe it.

        Returns "" (after printing the error) if the file cannot be loaded;
        errors raised during inference propagate to the caller.
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""
        return self.buffer_to_text(audio_input)
# The ONNX export code below (convert_to_onnx) was taken from:
# https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
class OptimizedWave2Vec2Factory:
    """Factory class to create the most optimized Wave2Vec2 inference instance"""

    @staticmethod
    def create_optimized_inference(model_name, onnx_path=None, safe_mode=False, **kwargs):
        """
        Create the most optimized inference instance based on available resources

        Args:
            model_name: HuggingFace model name
            onnx_path: Path to ONNX model (optional, for maximum speed). Used only
                if the file exists; otherwise a notice is printed and the
                PyTorch backend is used.
            safe_mode: If True, disable aggressive optimizations that might cause
                issues (PyTorch backend only).
            **kwargs: Additional arguments for Wave2Vec2Inference (ignored by the
                ONNX backend).

        Returns:
            Optimized inference instance
        """
        if onnx_path and os.path.exists(onnx_path):
            print("Using ONNX model for maximum speed")
            return Wave2Vec2ONNXInference(model_name, onnx_path)

        if onnx_path:
            # An explicit path that does not exist used to fall back silently.
            print(f"ONNX model not found at {onnx_path}; falling back to PyTorch")
        print("Using PyTorch model with optimizations")
        # In safe mode, disable optimizations that might cause issues
        if safe_mode:
            kwargs['enable_optimizations'] = False
            print("Running in safe mode - optimizations disabled")
        return Wave2Vec2Inference(model_name, **kwargs)

    @staticmethod
    def create_safe_inference(model_name, **kwargs):
        """Create a PyTorch inference instance with optimizations and warm-up disabled."""
        kwargs['enable_optimizations'] = False
        return Wave2Vec2Inference(model_name, **kwargs)
def convert_to_onnx(model_id_or_path, onnx_model_name, dummy_audio_len=250_000, opset_version=14):
    """Export a ``Wav2Vec2ForCTC`` checkpoint to an ONNX file.

    Args:
        model_id_or_path: Model id or local path for ``from_pretrained``.
        onnx_model_name: Output path (or file-like object) for the ONNX model.
        dummy_audio_len: Number of samples in the random input used for tracing.
        opset_version: ONNX opset version to target.

    The graph has one input named ``"input"`` with shape (1, samples) and one
    output named ``"output"``. The input's sample axis and the output's frame
    axis are both dynamic and get separate symbolic names, because the model
    emits far fewer frames than it receives samples.
    """
    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    # Gradients are not needed to trace the graph for export.
    dummy_input = torch.randn(1, dummy_audio_len)
    torch.onnx.export(
        model,                          # model being run
        dummy_input,                    # model input (or a tuple for multiple inputs)
        onnx_model_name,                # where to save the model (file or file-like object)
        export_params=True,             # store the trained parameter weights inside the model file
        opset_version=opset_version,    # the ONNX version to export the model to
        do_constant_folding=True,       # whether to execute constant folding for optimization
        input_names=["input"],          # the model's input names
        output_names=["output"],        # the model's output names
        dynamic_axes={
            "input": {1: "audio_len"},  # variable number of audio samples
            "output": {1: "frames"},    # variable number of output frames (not equal to audio_len)
        },
    )
def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Dynamically quantize an ONNX model and write the result to disk.

    Args:
        onnx_model_path: Path of the float ONNX model to read.
        quantized_model_path: Destination path for the quantized model.
    """
    print("Starting quantization...")
    # Deferred import: the quantization helpers are only loaded when this runs.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    weight_dtype = QuantType.QUInt8  # weights stored as unsigned 8-bit integers
    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=weight_dtype)
    print(f"Quantized model saved to: {quantized_model_path}")
def export_to_onnx(
    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
    """Export ``model`` to ``<name>.onnx`` in the current directory.

    ``<name>`` is the last path component of ``model``: the repo name for a Hub
    id like ``org/name``, or the directory name for a local path, with or
    without a trailing separator. With ``quantize=True`` a dynamically
    quantized ``<name>.quant.onnx`` is written as well.
    """
    # normpath drops trailing separators (which made split("/")[-1] return "")
    # and basename also handles OS-specific separators.
    base_name = os.path.basename(os.path.normpath(model))
    onnx_model_name = base_name + ".onnx"
    convert_to_onnx(model, onnx_model_name)
    if quantize:
        quantized_model_name = base_name + ".quant.onnx"
        quantize_onnx_model(onnx_model_name, quantized_model_name)
if __name__ == "__main__":
    # Simple latency benchmark: transcribe test.wav repeatedly and print stats.
    import time

    WARMUP_RUNS = 2
    TEST_RUNS = 10
    test_file = "test.wav"

    # Check the input before loading the (large) model so a missing file fails fast.
    if not os.path.exists(test_file):
        print(f"Test file {test_file} not found. Please provide a valid audio file.")
        raise SystemExit(1)

    # Use optimized factory to create the best inference instance
    asr = OptimizedWave2Vec2Factory.create_optimized_inference(
        "facebook/wav2vec2-large-960h-lv60-self"
    )

    # Warm up runs (model may already be warmed up during initialization)
    print("Running additional warm-up...")
    for i in range(WARMUP_RUNS):
        asr.file_to_text(test_file)
        print(f"Warm up {i+1} completed")

    print("Running optimized performance tests...")
    times = []
    for i in range(TEST_RUNS):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        text = asr.file_to_text(test_file)
        execution_time = time.perf_counter() - start_time
        times.append(execution_time)
        print(f"Test {i+1}: {execution_time:.3f}s - {text}")

    average_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    std_time = np.std(times)
    print("\n=== Performance Statistics ===")
    print(f"Average execution time: {average_time:.3f}s")
    print(f"Min time: {min_time:.3f}s")
    print(f"Max time: {max_time:.3f}s")
    print(f"Standard deviation: {std_time:.3f}s")
    # Relative gap between slowest and fastest run, not a speed-up over a baseline.
    print(f"Run-to-run spread: {((max_time - min_time) / max_time * 100):.1f}% (max vs min)")
    print(f"Average throughput: {1.0 / average_time:.2f} inferences/second")