import torch
import onnx
import onnxruntime
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from typing import Dict
import librosa

class Wav2Vec2ONNXConverter:
    """Convert a Wav2Vec2 model to ONNX format."""

    def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
        """Initialize the converter with the specified model."""
        print(f"Loading Wav2Vec2 model: {model_name}")
        self.model_name = model_name
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

        # Disable flash attention so the export traces the standard (eager)
        # attention path, which the ONNX exporter supports
        if hasattr(self.model.config, "use_flash_attention_2"):
            self.model.config.use_flash_attention_2 = False

        # Rebuild the attention dropout modules as plain nn.Dropout; a no-op
        # on standard checkpoints, kept as a safeguard for variant configs
        if hasattr(self.model, "wav2vec2") and hasattr(self.model.wav2vec2, "encoder"):
            for layer in self.model.wav2vec2.encoder.layers:
                if hasattr(layer.attention, "attention_dropout"):
                    layer.attention.attention_dropout = torch.nn.Dropout(
                        layer.attention.attention_dropout.p
                    )

        self.model.eval()
        self.sample_rate = 16000
        print("Model loaded successfully")
    def convert_to_onnx(self,
                        onnx_path: str = "wav2vec2_model.onnx",
                        input_length: int = 160000,  # 10 seconds at 16 kHz
                        opset_version: int = 14) -> str:
        """
        Convert the Wav2Vec2 model to ONNX format.

        Args:
            onnx_path: Path to save the ONNX model
            input_length: Length of the dummy input audio, in samples
            opset_version: ONNX opset version

        Returns:
            Path to the saved ONNX model
        """
        print("Converting model to ONNX format...")

        # Dummy input used to trace the model
        dummy_input = torch.randn(1, input_length, dtype=torch.float32)

        input_names = ["input_values"]
        output_names = ["logits"]

        # Dynamic axes for variable-length input; the logits time axis is
        # shorter than the raw audio axis (the conv feature extractor
        # downsamples by roughly 320x), so it gets its own dimension name
        dynamic_axes = {
            "input_values": {0: "batch_size", 1: "sequence_length"},
            "logits": {0: "batch_size", 1: "logits_sequence_length"}
        }
        try:
            with torch.no_grad():
                # Evaluation mode with all dropout disabled for a
                # deterministic trace
                self.model.eval()
                for module in self.model.modules():
                    if isinstance(module, torch.nn.Dropout):
                        module.p = 0.0

                # Export to ONNX
                torch.onnx.export(
                    self.model,
                    dummy_input,
                    onnx_path,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                    opset_version=opset_version,
                    do_constant_folding=True,
                    verbose=False,
                    export_params=True,
                    training=torch.onnx.TrainingMode.EVAL,
                    operator_export_type=torch.onnx.OperatorExportTypes.ONNX
                )

            print(f"Model successfully exported to: {onnx_path}")

            # Verify the exported model
            self._verify_onnx_model(onnx_path, dummy_input)
            return onnx_path
        except Exception as e:
            print(f"Error during ONNX conversion: {e}")
            raise
    def _verify_onnx_model(self, onnx_path: str, test_input: torch.Tensor):
        """Verify the exported ONNX model."""
        print("Verifying ONNX model...")
        try:
            # Load and structurally check the ONNX model
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
            print("✓ ONNX model structure is valid")

            # Test inference with ONNX Runtime
            ort_session = onnxruntime.InferenceSession(onnx_path)
            input_name = ort_session.get_inputs()[0].name
            output_name = ort_session.get_outputs()[0].name
            print(f"✓ Input name: {input_name}")
            print(f"✓ Output name: {output_name}")

            # Run inference
            ort_inputs = {input_name: test_input.numpy()}
            ort_outputs = ort_session.run([output_name], ort_inputs)

            # Compare against the original PyTorch model
            with torch.no_grad():
                torch_logits = self.model(test_input).logits

            onnx_logits = ort_outputs[0]
            max_diff = np.max(np.abs(torch_logits.numpy() - onnx_logits))
            print(f"✓ Maximum difference between PyTorch and ONNX: {max_diff:.6f}")
            if max_diff < 1e-4:
                print("✓ ONNX model verification successful!")
            else:
                print("⚠ Warning: Large difference detected between models")
        except Exception as e:
            print(f"Error during verification: {e}")
            raise

class Wav2Vec2ONNXInference:
    """ONNX inference class for Wav2Vec2."""

    def __init__(self, onnx_path: str, processor_name: str = "facebook/wav2vec2-base-960h"):
        """Initialize ONNX inference."""
        print(f"Loading ONNX model from: {onnx_path}")

        # Processor handles feature extraction and CTC decoding
        self.processor = Wav2Vec2Processor.from_pretrained(processor_name)

        # Create the ONNX Runtime session
        self.session = onnxruntime.InferenceSession(onnx_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        self.sample_rate = 16000
        print("ONNX model loaded successfully")

    def transcribe(self, audio_path: str) -> Dict:
        """Transcribe audio using the ONNX model."""
        try:
            # Load audio, resampling to 16 kHz
            speech, _ = librosa.load(audio_path, sr=self.sample_rate)

            # Prepare model input
            input_values = self.processor(
                speech,
                sampling_rate=self.sample_rate,
                return_tensors="np"
            ).input_values

            # Run ONNX inference
            ort_inputs = {self.input_name: input_values}
            ort_outputs = self.session.run([self.output_name], ort_inputs)
            logits = ort_outputs[0]

            # Greedy CTC decoding
            predicted_ids = np.argmax(logits, axis=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]

            # Per-frame confidence scores from the softmax maxima
            confidence_scores = np.max(self._softmax(logits), axis=-1)[0]

            return {
                "transcription": transcription,
                "confidence_scores": confidence_scores[:100].tolist(),  # limit for JSON
                "predicted_ids": predicted_ids[0].tolist()
            }
        except Exception as e:
            print(f"Transcription error: {e}")
            return {
                "transcription": "",
                "confidence_scores": [],
                "predicted_ids": []
            }

    def _softmax(self, x):
        """Apply a numerically stable softmax to the logits."""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
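
# The exported model accepts variable-length input, but very long recordings
# can exhaust memory in a single pass. A minimal sketch of chunked inference,
# assuming the Wav2Vec2ONNXInference class above; chunk boundaries can split
# words, and overlap/merge strategies are omitted for brevity.
def transcribe_long_audio(inference: Wav2Vec2ONNXInference,
                          audio_path: str,
                          chunk_seconds: float = 10.0) -> str:
    """Transcribe long audio by running the ONNX model chunk by chunk."""
    speech, _ = librosa.load(audio_path, sr=inference.sample_rate)
    chunk_size = int(chunk_seconds * inference.sample_rate)
    pieces = []
    for start in range(0, len(speech), chunk_size):
        chunk = speech[start:start + chunk_size]
        if len(chunk) < 1600:  # skip fragments too short for the conv frontend
            continue
        input_values = inference.processor(
            chunk, sampling_rate=inference.sample_rate, return_tensors="np"
        ).input_values
        logits = inference.session.run(
            [inference.output_name], {inference.input_name: input_values}
        )[0]
        predicted_ids = np.argmax(logits, axis=-1)
        pieces.append(inference.processor.batch_decode(predicted_ids)[0])
    return " ".join(piece for piece in pieces if piece)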

# Example usage and testing
def main():
    """Example usage of the converter."""
    # Method 1: standard conversion
    try:
        print("Method 1: Standard conversion...")
        converter = Wav2Vec2ONNXConverter("facebook/wav2vec2-base-960h")
        onnx_path = converter.convert_to_onnx(
            onnx_path="wav2vec2_asr.onnx",
            input_length=160000,  # 10 seconds
            opset_version=14
        )
        print("✓ Standard conversion successful!")
    except Exception as e:
        print(f"✗ Standard conversion failed: {e}")
        print("\nMethod 2: Trying fallback approach...")
        try:
            # Method 2: export with a compatibility-adjusted model
            model, processor = create_compatible_model("facebook/wav2vec2-base-960h")
            onnx_path = export_with_fallback(
                model,
                processor,
                "wav2vec2_asr_fallback.onnx",
                input_length=160000
            )
            print("✓ Fallback conversion successful!")
        except Exception as e2:
            print(f"✗ All conversion methods failed: {e2}")
            return

    # Test ONNX inference
    print("\nTesting ONNX inference...")
    try:
        onnx_inference = Wav2Vec2ONNXInference(onnx_path)
        print("✓ ONNX model loaded successfully for inference")
        # Transcribe a test audio file of your own:
        # result = onnx_inference.transcribe("test_audio.wav")
        # print("Transcription:", result["transcription"])
    except Exception as e:
        print(f"✗ ONNX inference test failed: {e}")

    print("Conversion process completed!")

# Additional utility functions
def create_compatible_model(model_name: str = "facebook/wav2vec2-base-960h"):
    """Create a Wav2Vec2 model compatible with ONNX export."""
    from transformers import Wav2Vec2Config

    # Load the config and adjust options known to cause ONNX issues
    config = Wav2Vec2Config.from_pretrained(model_name)
    if hasattr(config, "use_flash_attention_2"):
        config.use_flash_attention_2 = False
    if hasattr(config, "torch_dtype"):
        config.torch_dtype = torch.float32

    # Load the model with the modified config, forcing float32 weights
    model = Wav2Vec2ForCTC.from_pretrained(model_name, config=config, torch_dtype=torch.float32)
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    return model, processor

def export_with_fallback(model, processor, onnx_path: str, input_length: int = 160000):
    """Export the model, falling back through progressively older opset versions."""
    dummy_input = torch.randn(1, input_length, dtype=torch.float32)
    input_names = ["input_values"]
    output_names = ["logits"]
    dynamic_axes = {
        "input_values": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "logits_sequence_length"}
    }

    # Try opset versions from newest to oldest
    opset_versions = [14, 13, 12, 11]
    for opset_version in opset_versions:
        try:
            print(f"Trying ONNX export with opset version {opset_version}...")
            with torch.no_grad():
                model.eval()
                # Disable all dropout for a deterministic trace
                for module in model.modules():
                    if isinstance(module, torch.nn.Dropout):
                        module.p = 0.0
                torch.onnx.export(
                    model,
                    dummy_input,
                    onnx_path,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                    opset_version=opset_version,
                    do_constant_folding=True,
                    verbose=False,
                    export_params=True,
                    training=torch.onnx.TrainingMode.EVAL
                )
            print(f"✓ Successfully exported with opset version {opset_version}")
            return onnx_path
        except Exception as e:
            print(f"✗ Failed with opset {opset_version}: {str(e)[:100]}...")
            continue
    raise Exception("Failed to export with all attempted opset versions")
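
# Optional post-export step: dynamic INT8 quantization. A minimal sketch,
# assuming the onnxruntime.quantization package is installed; weights are
# stored as INT8 while activations stay float, trading a small accuracy
# loss for a smaller and often faster model.
def quantize_onnx_model(onnx_path: str, quantized_path: str = None) -> str:
    """Apply dynamic quantization to an exported ONNX model."""
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
    except ImportError:
        print("onnxruntime.quantization not available; skipping quantization")
        return onnx_path
    if quantized_path is None:
        quantized_path = onnx_path.replace(".onnx", "_int8.onnx")
    # Quantize weights to INT8; activation ranges are computed at runtime
    quantize_dynamic(onnx_path, quantized_path, weight_type=QuantType.QInt8)
    print(f"Quantized model saved to: {quantized_path}")
    return quantized_path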

def optimize_onnx_model(onnx_path: str, optimized_path: str = None):
    """Optimize the ONNX model graph for inference."""
    try:
        from onnxruntime.transformers import optimizer

        if optimized_path is None:
            optimized_path = onnx_path.replace(".onnx", "_optimized.onnx")

        # Fuse and optimize the graph; "bert" works for transformer encoders
        # with a similar attention layout
        opt_model = optimizer.optimize_model(
            onnx_path,
            model_type="bert",
            num_heads=12,     # wav2vec2-base attention heads
            hidden_size=768   # wav2vec2-base hidden size
        )
        opt_model.save_model_to_file(optimized_path)
        print(f"Optimized model saved to: {optimized_path}")
        return optimized_path
    except ImportError:
        print("ONNX Runtime transformer optimizer not available")
        return onnx_path
    except Exception as e:
        print(f"Optimization error: {e}")
        return onnx_path
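
# Much of the graph optimization can also be left to ONNX Runtime at session
# creation. A minimal sketch using SessionOptions; provider availability
# depends on how your onnxruntime build was compiled.
def create_optimized_session(onnx_path: str) -> onnxruntime.InferenceSession:
    """Create an InferenceSession with full graph optimization enabled."""
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    # CPUExecutionProvider is always available; list e.g.
    # "CUDAExecutionProvider" first if your build supports it
    providers = ["CPUExecutionProvider"]
    return onnxruntime.InferenceSession(onnx_path, sess_options, providers=providers)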

def compare_models(original_converter, onnx_inference, test_audio_path: str):
    """Compare PyTorch and ONNX model transcriptions on the same audio."""
    print("Comparing PyTorch vs ONNX outputs...")

    # PyTorch inference via the converter's model and processor
    speech, _ = librosa.load(test_audio_path, sr=original_converter.sample_rate)
    inputs = original_converter.processor(
        speech, sampling_rate=original_converter.sample_rate, return_tensors="pt"
    )
    with torch.no_grad():
        logits = original_converter.model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    torch_transcription = original_converter.processor.batch_decode(predicted_ids)[0]

    # ONNX inference
    onnx_result = onnx_inference.transcribe(test_audio_path)

    print(f"PyTorch transcription: {torch_transcription}")
    print(f"ONNX transcription: {onnx_result['transcription']}")
    if torch_transcription == onnx_result["transcription"]:
        print("✓ Transcriptions match exactly!")
    else:
        print("⚠ Transcriptions differ")

if __name__ == "__main__":
    main()