# coding=utf-8
# Copyright 2023 Xiaomi Corporation and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor class for CED.
"""
from typing import List, Optional, Union

import numpy as np
import torch
import torchaudio.transforms as audio_transforms

from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import logging

logger = logging.get_logger(__name__)


class CedFeatureExtractor(SequenceFeatureExtractor):
r"""
    CedFeatureExtractor extracts log-Mel spectrogram features from audio signals.

    Args:
f_min (int, *optional*, defaults to 0): Minimum frequency for the Mel filterbank.
sampling_rate (int, *optional*, defaults to 16000):
Sampling rate of the input audio signal.
win_size (int, *optional*, defaults to 512): Window size for the STFT.
center (bool, *optional*, defaults to `True`):
Whether to pad the signal on both sides to center it.
n_fft (int, *optional*, defaults to 512): Number of FFT points for the STFT.
        f_max (int, *optional*): Maximum frequency for the Mel filterbank. If `None`, the Nyquist frequency is used.
hop_size (int, *optional*, defaults to 160): Hop size for the STFT.
feature_size (int, *optional*, defaults to 64): Number of Mel bands to generate.
padding_value (float, *optional*, defaults to 0.0): Value for padding.
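
    Example (a minimal sketch; the import path assumes this repository's `ced_model` package layout):

        >>> import torch
        >>> from ced_model.feature_extraction_ced import CedFeatureExtractor

        >>> feature_extractor = CedFeatureExtractor()
        >>> waveform = torch.zeros(16000)  # 1 second of silence at 16 kHz
        >>> features = feature_extractor(waveform, sampling_rate=16000)
        >>> tuple(features["input_values"].shape)  # (batch, n_mels, frames)
        (1, 64, 101)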
"""
def __init__(
self,
f_min: int = 0,
sampling_rate: int = 16000,
win_size: int = 512,
center: bool = True,
n_fft: int = 512,
f_max: Optional[int] = None,
hop_size: int = 160,
feature_size: int = 64,
padding_value: float = 0.0,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
**kwargs,
)
self.f_min = f_min
self.win_size = win_size
self.center = center
self.n_fft = n_fft
self.f_max = f_max
self.hop_size = hop_size
        self.model_input_names = ["input_values"]

    def __call__(
self,
x: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
sampling_rate: Optional[int] = None,
max_length: Optional[int] = 16000,
truncation: bool = True,
        return_tensors: str = "pt",
) -> BatchFeature:
r"""
        Extracts log-Mel spectrogram features from one or more audio signals.

        Args:
            x (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, or `List[torch.Tensor]`):
                Input audio signal(s): a 1D or 2D array/tensor, or a list of 1D clips.
sampling_rate (int, *optional*, defaults to `None`):
Sampling rate of the input audio signal.
            max_length (int, *optional*, defaults to 16000):
                Maximum length of the input audio signal, in samples. List inputs are padded
                (and, if `truncation` is set, truncated) to this length.
truncation (bool, *optional*, defaults to `True`):
Whether to truncate the input signal to max_length.
return_tensors (str, *optional*, defaults to "pt"):
If set to "pt", the return type will be a PyTorch tensor.

        Returns:
            BatchFeature: A dictionary holding the log-Mel features under `"input_values"`.
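
        Example (a minimal sketch; clip lengths are illustrative):

            >>> import numpy as np
            >>> extractor = CedFeatureExtractor()
            >>> clips = [np.zeros(12000, dtype=np.float32), np.zeros(20000, dtype=np.float32)]
            >>> features = extractor(clips, sampling_rate=16000, max_length=16000, truncation=True)
            >>> tuple(features["input_values"].shape)  # (batch, n_mels, frames)
            (2, 64, 101)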
"""
if sampling_rate is None:
sampling_rate = self.sampling_rate
if return_tensors != "pt":
raise NotImplementedError("Only return_tensors='pt' is currently supported.")
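        # The transforms are rebuilt on each call so that a per-call `sampling_rate` override takes effect.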
mel_spectrogram = audio_transforms.MelSpectrogram(
f_min=self.f_min,
sample_rate=sampling_rate,
win_length=self.win_size,
center=self.center,
n_fft=self.n_fft,
f_max=self.f_max,
hop_length=self.hop_size,
n_mels=self.feature_size,
)
amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
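        # Normalize the accepted input types to a single 2D (batch, samples) tensor.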
if isinstance(x, np.ndarray):
if x.ndim == 1:
x = x[np.newaxis, :]
            if x.ndim != 2:
                raise ValueError("np.ndarray input must be 1D or 2D.")
x = torch.from_numpy(x)
elif isinstance(x, torch.Tensor):
if x.dim() == 1:
x = x.unsqueeze(0)
            if x.dim() != 2:
                raise ValueError("torch.Tensor input must be 1D or 2D.")
        elif isinstance(x, (list, tuple)):
            longest_length = max(x_.shape[0] for x_ in x)
            if max_length is None or (not truncation and max_length < longest_length):
                # Without an explicit max_length (or without truncation), pad to the longest clip.
                max_length = longest_length
if all(isinstance(x_, np.ndarray) for x_ in x):
if not all(x_.ndim == 1 for x_ in x):
raise ValueError("All np.ndarray in a list must be 1D.")
x_trim = [x_[:max_length] for x_ in x]
x_pad = [np.pad(x_, (0, max_length - x_.shape[0]), mode="constant", constant_values=0) for x_ in x_trim]
x = torch.stack([torch.from_numpy(x_) for x_ in x_pad])
            elif all(isinstance(x_, torch.Tensor) for x_ in x):
                if not all(x_.dim() == 1 for x_ in x):
                    raise ValueError("All torch.Tensor in a list must be 1D.")
                # Truncate first, then zero-pad to max_length (mirroring the NumPy branch above).
                x_trim = [x_[:max_length] for x_ in x]
                x_pad = [torch.nn.functional.pad(x_, (0, max_length - x_.shape[0]), value=0) for x_ in x_trim]
                x = torch.stack(x_pad)
            else:
                raise ValueError("Input list must contain only numpy arrays or only PyTorch tensors.")
        else:
            raise ValueError(
                "Input must be a numpy array, a list of numpy arrays, a PyTorch tensor, or a list of PyTorch tensors."
            )
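        # Convert to float and compute the log-Mel spectrogram: (batch, samples) -> (batch, n_mels, frames).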
x = x.float()
x = mel_spectrogram(x)
x = amplitude_to_db(x)
return BatchFeature({"input_values": x})