# Simple-KWS / audio_processor.py
# Uploaded by IvanLayer7 ("Upload audio_processor.py", commit c518e1d, verified)
"""
Audio processing module for zero-shot keyword spotting.
Handles audio loading, preprocessing, and feature extraction.
"""
import librosa
import numpy as np
import torch
from typing import Union, Tuple
import warnings
warnings.filterwarnings("ignore")
class AudioProcessor:
    """Handles audio preprocessing for the keyword spotting model.

    Every waveform is turned into a mono float32 tensor of exactly
    ``max_samples`` samples at ``target_sample_rate``: down-mixed, resampled,
    trimmed or zero-padded, then RMS-normalized and clipped to [-1, 1].
    """

    def __init__(self, target_sample_rate: int = 48000, max_duration: float = 30.0):
        """
        Initialize the audio processor.

        Args:
            target_sample_rate: Target sampling rate for audio processing
            max_duration: Maximum audio duration in seconds
        """
        self.target_sample_rate = target_sample_rate
        self.max_duration = max_duration
        # Fixed length (in samples) of every tensor produced by preprocess_audio.
        self.max_samples = int(target_sample_rate * max_duration)

    def load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """
        Load audio file and return waveform and sample rate.

        Args:
            audio_path: Path to the audio file

        Returns:
            Tuple of (waveform, sample_rate)

        Raises:
            ValueError: If the file cannot be loaded; the underlying error is
                attached as the exception's cause.
        """
        try:
            # sr=None keeps the file's native rate; resampling is done later
            # in preprocess_audio.
            waveform, sr = librosa.load(audio_path, sr=None)
        except Exception as e:
            raise ValueError(f"Error loading audio file: {e}") from e
        return waveform, sr

    @staticmethod
    def _to_mono(waveform: np.ndarray) -> np.ndarray:
        """Down-mix a multi-channel waveform to a 1-D array by averaging channels.

        A 2-D array may be channels-first ``(channels, samples)`` or
        channels-last ``(samples, channels)``. The shorter axis is taken to be
        the channel axis; on a tie the channels-first layout is assumed.
        Arrays with more than two dimensions are averaged over every axis
        except the last.
        """
        if waveform.ndim <= 1:
            return waveform
        if waveform.ndim == 2 and waveform.shape[0] > waveform.shape[1]:
            # Channels-last layout: the long axis (time) comes first.
            return waveform.mean(axis=1)
        return waveform.mean(axis=tuple(range(waveform.ndim - 1)))

    def preprocess_audio(self, waveform: np.ndarray, sample_rate: int) -> torch.Tensor:
        """
        Preprocess audio waveform for model input.

        Args:
            waveform: Audio waveform as numpy array; 1-D mono, or 2-D
                multi-channel in either channels-first or channels-last layout
            sample_rate: Original sample rate

        Returns:
            Preprocessed 1-D float32 audio tensor of length ``max_samples``
        """
        waveform = np.asarray(waveform, dtype=np.float32)

        # Down-mix BEFORE resampling: the resampler operates along the last
        # axis, which is the channel axis for channels-last input, and a
        # single channel is also cheaper to resample.
        waveform = self._to_mono(waveform)

        # Resample if necessary (an empty signal has nothing to resample).
        if sample_rate != self.target_sample_rate and waveform.size > 0:
            waveform = librosa.resample(
                waveform,
                orig_sr=sample_rate,
                target_sr=self.target_sample_rate,
            )

        # Trim or zero-pad to exactly max_samples.
        n_samples = len(waveform)
        if n_samples > self.max_samples:
            waveform = waveform[:self.max_samples]
        elif n_samples < self.max_samples:
            padding = self.max_samples - n_samples
            waveform = np.pad(waveform, (0, padding), mode='constant', constant_values=0)

        # NOTE: normalization runs on the padded buffer, so the measured RMS
        # includes the appended silence. The order is kept deliberately so the
        # values fed to the model stay the same.
        waveform = self._normalize_audio(waveform)

        return torch.from_numpy(waveform).float()

    def _normalize_audio(self, waveform: np.ndarray) -> np.ndarray:
        """
        Normalize audio waveform.

        Scales the signal so its RMS becomes 0.1, then clips to [-1, 1].
        An all-zero signal is returned unscaled.

        Args:
            waveform: Input waveform

        Returns:
            Normalized waveform
        """
        rms = np.sqrt(np.mean(waveform**2))
        if rms > 0:
            # Dividing by 10x the RMS leaves headroom before the clip below.
            waveform = waveform / (rms * 10)
        return np.clip(waveform, -1.0, 1.0)

    def process_audio_file(self, audio_path: str) -> torch.Tensor:
        """
        Complete audio processing pipeline from file to tensor.

        Args:
            audio_path: Path to audio file

        Returns:
            Preprocessed audio tensor ready for model input

        Raises:
            ValueError: If the audio file cannot be loaded.
        """
        waveform, sample_rate = self.load_audio(audio_path)
        return self.preprocess_audio(waveform, sample_rate)

    def process_audio_array(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """
        Process audio from numpy array (e.g., from Gradio microphone input).

        Args:
            audio_array: Audio data as numpy array (mono or multi-channel)
            sample_rate: Sample rate of the audio

        Returns:
            Preprocessed audio tensor
        """
        return self.preprocess_audio(audio_array, sample_rate)