|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Union |
|
|
|
import numpy as np |
|
import torch |
|
import torchaudio |
|
from lhotse.features.base import FeatureExtractor, register_extractor |
|
from lhotse.utils import Seconds, compute_num_frames |
|
|
|
|
|
@dataclass |
|
class VocosFbankConfig: |
|
sampling_rate: int = 24000 |
|
n_mels: int = 100 |
|
n_fft: int = 1024 |
|
hop_length: int = 256 |
|
|
|
|
|
@register_extractor |
|
class VocosFbank(FeatureExtractor): |
|
|
|
name = "VocosFbank" |
|
config_type = VocosFbankConfig |
|
|
|
def __init__(self, num_channels: int = 1): |
|
config = VocosFbankConfig |
|
super().__init__(config=config) |
|
assert num_channels in (1, 2) |
|
self.num_channels = num_channels |
|
self.fbank = torchaudio.transforms.MelSpectrogram( |
|
sample_rate=self.config.sampling_rate, |
|
n_fft=self.config.n_fft, |
|
hop_length=self.config.hop_length, |
|
n_mels=self.config.n_mels, |
|
center=True, |
|
power=1, |
|
) |
|
|
|
def _feature_fn(self, sample): |
|
mel = self.fbank(sample) |
|
logmel = mel.clamp(min=1e-7).log() |
|
|
|
return logmel |
|
|
|
@property |
|
def device(self) -> Union[str, torch.device]: |
|
return self.config.device |
|
|
|
def feature_dim(self, sampling_rate: int) -> int: |
|
return self.config.n_mels |
|
|
|
def extract( |
|
self, |
|
samples: Union[np.ndarray, torch.Tensor], |
|
sampling_rate: int, |
|
) -> Union[np.ndarray, torch.Tensor]: |
|
|
|
expected_sr = self.config.sampling_rate |
|
assert sampling_rate == expected_sr, ( |
|
f"Mismatched sampling rate: extractor expects {expected_sr}, " |
|
f"got {sampling_rate}" |
|
) |
|
is_numpy = False |
|
if not isinstance(samples, torch.Tensor): |
|
samples = torch.from_numpy(samples) |
|
is_numpy = True |
|
|
|
if len(samples.shape) == 1: |
|
samples = samples.unsqueeze(0) |
|
else: |
|
assert samples.ndim == 2, samples.shape |
|
|
|
if self.num_channels == 1: |
|
if samples.shape[0] == 2: |
|
samples = samples.mean(dim=0, keepdims=True) |
|
else: |
|
assert samples.shape[0] == 2, samples.shape |
|
|
|
mel = self._feature_fn(samples) |
|
|
|
mel = mel.reshape(-1, mel.shape[-1]).t() |
|
|
|
|
|
num_frames = compute_num_frames( |
|
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate |
|
) |
|
|
|
if mel.shape[0] > num_frames: |
|
mel = mel[:num_frames] |
|
elif mel.shape[0] < num_frames: |
|
mel = mel.unsqueeze(0) |
|
mel = torch.nn.functional.pad( |
|
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" |
|
).squeeze(0) |
|
|
|
if is_numpy: |
|
return mel.cpu().numpy() |
|
else: |
|
return mel |
|
|
|
@property |
|
def frame_shift(self) -> Seconds: |
|
return self.config.hop_length / self.config.sampling_rate |
|
|