import torch
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    AutoModelForCTC,
)
import librosa
import numpy as np
from typing import Optional, List, Union


def get_model_name(model_name: Optional[str] = None) -> str:
    """Helper function to get model name with default fallback."""
    if model_name is None:
        return "facebook/wav2vec2-large-robust-ft-libri-960h"
    return model_name


class Wave2Vec2Inference:
    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
        use_deepspeed: bool = True,
    ):
        """
        Initialize Wav2Vec2 model for inference with optional DeepSpeed optimization.

        Args:
            model_name: HuggingFace model name, or None for the default
            use_gpu: Whether to use GPU acceleration
            use_deepspeed: Whether to use DeepSpeed optimization
        """
        # Resolve the model name, falling back to the default
        self.model_name = get_model_name(model_name)
        self.use_deepspeed = use_deepspeed

        # Auto-detect device
        if use_gpu:
            if torch.backends.mps.is_available():
                self.device = "mps"
            elif torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"
        else:
            self.device = "cpu"

        print(f"Using device: {self.device}")
        print(f"Loading model: {self.model_name}")
        print(f"DeepSpeed enabled: {self.use_deepspeed}")

        # XLSR checkpoints need the explicit Wav2Vec2 classes
        is_xlsr = "xlsr" in self.model_name.lower()
        if is_xlsr:
            print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
        else:
            print("Using AutoProcessor and AutoModelForCTC")
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForCTC.from_pretrained(self.model_name)

        # Initialize DeepSpeed if enabled, otherwise plain PyTorch
        if self.use_deepspeed:
            self._init_deepspeed()
        else:
            self.model.to(self.device)
            self.model.eval()
            self.ds_engine = None

        # Disable gradients globally for inference
        torch.set_grad_enabled(False)

    def _init_deepspeed(self):
        """Initialize the DeepSpeed inference engine, falling back to PyTorch on failure."""
        try:
            # Import lazily so DeepSpeed stays an optional dependency;
            # a missing package is caught by the except clause below
            import deepspeed

            # Kernel injection is only supported on CUDA
            if self.device == "cuda":
                ds_config = {
                    "tensor_parallel": {"tp_size": 1},
                    "dtype": torch.float32,
                    "replace_with_kernel_inject": True,
                    "enable_cuda_graph": False,
                }
            else:
                ds_config = {
                    "tensor_parallel": {"tp_size": 1},
                    "dtype": torch.float32,
                    "replace_with_kernel_inject": False,
                    "enable_cuda_graph": False,
                }

            print("Initializing DeepSpeed inference engine...")
            self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
            self.ds_engine.module.to(self.device)
        except Exception as e:
            print(f"DeepSpeed initialization failed: {e}")
            print("Falling back to standard PyTorch inference...")
            self.use_deepspeed = False
            self.ds_engine = None
            self.model.to(self.device)
            self.model.eval()

    def _get_model(self):
        """Get the appropriate model for inference."""
        if self.use_deepspeed and self.ds_engine is not None:
            return self.ds_engine.module
        return self.model
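    # The transcription path below is a standard CTC pipeline: the processor
    # normalizes the raw 16 kHz waveform into model input values, the model
    # emits per-frame logits over the character vocabulary, a greedy argmax
    # picks one token per frame, and batch_decode collapses repeated tokens
    # and CTC blanks into the final text.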
    def buffer_to_text(
        self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
    ) -> str:
        """
        Convert an audio buffer to a text transcription.

        Args:
            audio_buffer: Audio data as a numpy array, tensor, or list

        Returns:
            str: Transcribed text
        """
        if len(audio_buffer) == 0:
            return ""

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        elif isinstance(audio_buffer, list):
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
        else:
            audio_tensor = audio_buffer.float()

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )

        # Move to device
        input_values = inputs.input_values.to(self.device)
        attention_mask = (
            inputs.attention_mask.to(self.device)
            if "attention_mask" in inputs
            else None
        )

        # Get the appropriate model
        model = self._get_model()

        # Inference
        with torch.no_grad():
            if attention_mask is not None:
                outputs = model(input_values, attention_mask=attention_mask)
            else:
                outputs = model(input_values)

        # Handle different output formats
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        else:
            logits = outputs

        # Greedy CTC decode
        predicted_ids = torch.argmax(logits, dim=-1)
        if self.device != "cpu":
            predicted_ids = predicted_ids.cpu()
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip()

    def file_to_text(self, filename: str) -> str:
        """
        Transcribe an audio file to text.

        Args:
            filename: Path to the audio file

        Returns:
            str: Transcribed text
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""

    def batch_file_to_text(self, filenames: List[str]) -> List[str]:
        """
        Transcribe multiple audio files to text.

        Args:
            filenames: List of audio file paths

        Returns:
            List[str]: List of transcribed texts
        """
        results = []
        for i, filename in enumerate(filenames):
            print(f"Processing file {i + 1}/{len(filenames)}: {filename}")
            transcription = self.file_to_text(filename)
            results.append(transcription)
            if transcription:
                print(f"Transcription: {transcription}")
            else:
                print("Failed to transcribe")
        return results

    def transcribe_with_confidence(
        self, audio_buffer: Union[np.ndarray, torch.Tensor]
    ) -> tuple:
        """
        Transcribe audio and return per-frame confidence scores.

        Args:
            audio_buffer: Audio data

        Returns:
            tuple: (transcription, confidence_scores)
        """
        if len(audio_buffer) == 0:
            return "", []

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = audio_buffer.float()

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )

        input_values = inputs.input_values.to(self.device)
        attention_mask = (
            inputs.attention_mask.to(self.device)
            if "attention_mask" in inputs
            else None
        )

        model = self._get_model()

        # Inference
        with torch.no_grad():
            if attention_mask is not None:
                outputs = model(input_values, attention_mask=attention_mask)
            else:
                outputs = model(input_values)

        if hasattr(outputs, "logits"):
            logits = outputs.logits
        else:
            logits = outputs

        # Confidence is the max softmax probability at each frame
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_ids = torch.argmax(logits, dim=-1)
        max_probs = torch.max(probs, dim=-1)[0]
        confidence_scores = max_probs.cpu().numpy().tolist()

        if self.device != "cpu":
            predicted_ids = predicted_ids.cpu()
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip(), confidence_scores

    def cleanup(self):
        """Clean up resources."""
        if hasattr(self, "ds_engine") and self.ds_engine is not None:
            del self.ds_engine
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "processor"):
            del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Destructor to clean up resources."""
        self.cleanup()
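
# A minimal sketch of chunked transcription for long recordings, built only
# on the public buffer_to_text API above. The helper name transcribe_long_file
# and the 30-second chunk length are illustrative choices, not part of the
# original class; naive fixed-size chunking can split words at chunk
# boundaries, so treat this as a starting point rather than a production
# segmenter.
def transcribe_long_file(
    asr: "Wave2Vec2Inference", filename: str, chunk_seconds: float = 30.0
) -> str:
    """Transcribe a long audio file by splitting it into fixed-size chunks."""
    audio, sr = librosa.load(filename, sr=16000, dtype=np.float32)
    chunk_size = int(chunk_seconds * sr)
    pieces = []
    # Slice the waveform into consecutive chunks and transcribe each one
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start : start + chunk_size]
        text = asr.buffer_to_text(chunk)
        if text:
            pieces.append(text)
    return " ".join(pieces)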

# Example usage
if __name__ == "__main__":
    # Initialize on CPU without DeepSpeed
    asr = Wave2Vec2Inference(
        model_name="facebook/wav2vec2-large-robust-ft-libri-960h",
        use_gpu=False,
        use_deepspeed=False,
    )

    # Single file transcription
    result = asr.file_to_text("./test_audio/hello_how_are_you_today.wav")
    print(f"Transcription: {result}")

    # # Batch processing
    # files = ["audio1.wav", "audio2.wav", "audio3.wav"]
    # batch_results = asr.batch_file_to_text(files)

    # # Transcription with confidence scores
    # audio_data, _ = librosa.load("path/to/audio.wav", sr=16000)
    # transcription, confidence = asr.transcribe_with_confidence(audio_data)
    # print(f"Transcription: {transcription}")
    # print(f"Average confidence: {np.mean(confidence):.3f}")

    # Cleanup
    asr.cleanup()