|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Feature extractor class for CED. |
|
""" |
|
|
|
from typing import Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torchaudio.transforms as audio_transforms |
|
|
|
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor |
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class CedFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Feature extractor that turns raw audio waveforms into log-Mel spectrograms for CED.

    Every call builds a `torchaudio.transforms.MelSpectrogram` from the settings stored on the instance, applies it to
    the waveform and converts the result to decibels with `torchaudio.transforms.AmplitudeToDB(top_db=120)`.

    Args:
        f_min (`int`, *optional*, defaults to 0):
            Minimum frequency for the Mel filterbank.
        sampling_rate (`int`, *optional*, defaults to 16000):
            Sampling rate of the input audio signal.
        win_size (`int`, *optional*, defaults to 512):
            Window size for the STFT (forwarded as `win_length`).
        center (`bool`, *optional*, defaults to `True`):
            Whether to pad the signal on both sides to center it.
        n_fft (`int`, *optional*, defaults to 512):
            Number of FFT points for the STFT.
        f_max (`int`, *optional*):
            Maximum frequency for the Mel filterbank. `None` is forwarded unchanged to
            `torchaudio.transforms.MelSpectrogram`, which then picks its own default.
        hop_size (`int`, *optional*, defaults to 160):
            Hop size for the STFT (forwarded as `hop_length`).
        feature_size (`int`, *optional*, defaults to 64):
            Number of Mel bands to generate (forwarded as `n_mels`).
        padding_value (`float`, *optional*, defaults to 0.0):
            Value for padding, handed to `SequenceFeatureExtractor`.
        **kwargs:
            Additional keyword arguments forwarded to `SequenceFeatureExtractor.__init__`.
    """

    def __init__(
        self,
        f_min: int = 0,
        sampling_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = None,
        hop_size: int = 160,
        feature_size: int = 64,
        padding_value: float = 0.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        # Only plain Python values are stored on the instance; the torchaudio transforms are created per call.
        self.f_min = f_min
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size

    def __call__(
        self,
        x: Union[np.ndarray, torch.Tensor, list],
        sampling_rate: Optional[int] = None,
        return_tensors="pt",
    ) -> BatchFeature:
        r"""
        Extracts log-Mel spectrogram features from an audio signal.

        Args:
            x (`np.ndarray`, `torch.Tensor` or `list`):
                Input audio signal. A one-dimensional input is treated as a single waveform and receives a leading
                batch dimension. Python sequences are converted with `np.asarray`, so nested sequences must be
                rectangular.
            sampling_rate (`int`, *optional*):
                Sampling rate of `x`. Defaults to the sampling rate the extractor was configured with. A different
                value is honoured when building the Mel filterbank, but a warning is logged.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                Requested tensor type. Only `"pt"` is supported.

        Returns:
            BatchFeature: A dictionary-like object whose `"input_values"` entry holds the spectrogram in decibels,
            with the trailing dimensions laid out as `(feature_size, frames)`.

        Raises:
            NotImplementedError: If `return_tensors` is anything other than `"pt"`.
        """
        if return_tensors != "pt":
            raise NotImplementedError(
                "Only return_tensors='pt' is currently supported."
            )

        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        elif sampling_rate != self.sampling_rate:
            # The override is kept for backward compatibility, but it shifts the Mel filterbank.
            logger.warning(
                "The `sampling_rate` passed to the feature extractor (%s) differs from the one it was configured "
                "with (%s). The Mel filterbank is built for the passed value, so the features may not match what "
                "the model expects.",
                sampling_rate,
                self.sampling_rate,
            )

        if not isinstance(x, torch.Tensor):
            # `np.asarray` leaves arrays untouched and lets plain Python sequences through.
            x = torch.from_numpy(np.asarray(x))
        waveform = x.float()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # The transform owns window/filterbank buffers that are created on CPU; move them next to the waveform so
        # that tensors living on another device do not trigger a device-mismatch error.
        mel_spectrogram = audio_transforms.MelSpectrogram(
            f_min=self.f_min,
            sample_rate=sampling_rate,
            win_length=self.win_size,
            center=self.center,
            n_fft=self.n_fft,
            f_max=self.f_max,
            hop_length=self.hop_size,
            n_mels=self.feature_size,
        ).to(waveform.device)
        amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)

        features = amplitude_to_db(mel_spectrogram(waveform))
        return BatchFeature({"input_values": features})
|
|