| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import random |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import torch |
| from packaging import version |
|
|
| from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics |
| from nemo.collections.asr.parts.preprocessing.features import ( |
| FilterbankFeatures, |
| FilterbankFeaturesTA, |
| make_seq_mask_like, |
| ) |
| from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout |
| from nemo.core.classes import Exportable, NeuralModule, typecheck |
| from nemo.core.neural_types import ( |
| AudioSignal, |
| LengthsType, |
| MelSpectrogramType, |
| MFCCSpectrogramType, |
| NeuralType, |
| SpectrogramType, |
| ) |
| from nemo.core.utils import numba_utils |
| from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ |
| from nemo.utils import logging |
|
|
| try: |
| import torchaudio |
| import torchaudio.functional |
| import torchaudio.transforms |
|
|
| TORCHAUDIO_VERSION = version.parse(torchaudio.__version__) |
| TORCHAUDIO_VERSION_MIN = version.parse('0.5') |
|
|
| HAVE_TORCHAUDIO = True |
| except ModuleNotFoundError: |
| HAVE_TORCHAUDIO = False |
|
|
# Public API of this module (preprocessors, STFT transforms, and augmentations).
__all__ = [
    'AudioToMelSpectrogramPreprocessor',
    'AudioToSpectrogram',
    'SpectrogramToAudio',
    'AudioToMFCCPreprocessor',
    'SpectrogramAugmentation',
    'MaskedPatchAugmentation',
    'CropOrPadSpectrogramAugmentation',
]
|
|
|
|
class AudioPreprocessor(NeuralModule, ABC):
    """
    An interface for Neural Modules that performs audio pre-processing,
    transforming the wav files to features.
    """

    def __init__(self, win_length, hop_length):
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length

        # Window-name -> constructor lookup shared by concrete featurizers.
        # Both 'ones' and None select a rectangular (all-ones) window.
        window_fns = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
        }
        window_fns['ones'] = torch.ones
        window_fns[None] = torch.ones
        self.torch_windows = window_fns

    @typecheck()
    @torch.no_grad()
    def forward(self, input_signal, length):
        """Featurize a batch of signals by delegating to the subclass' get_features()."""
        return self.get_features(input_signal, length)

    @abstractmethod
    def get_features(self, input_signal, length):
        """Compute (features, feature_lengths) for a batch of signals. Implemented by subclasses."""
        pass
|
|
|
|
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable):
    """Featurizer module that converts wavs to mel spectrograms.

    Args:
        sample_rate (int): Sample rate of the input audio data.
            Defaults to 16000
        window_size (float): Size of window for fft in seconds
            Defaults to 0.02
        window_stride (float): Stride of window for fft in seconds
            Defaults to 0.01
        n_window_size (int): Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride (int): Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window (str): Windowing function for fft. can be one of ['hann',
            'hamming', 'blackman', 'bartlett']
            Defaults to "hann"
        normalize (str): Can be one of ['per_feature', 'all_features']; all
            other options disable feature normalization. 'all_features'
            normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
            Defaults to "per_feature"
        n_fft (int): Length of FT window. If None, it uses the smallest power
            of 2 that is larger than n_window_size.
            Defaults to None
        preemph (float): Amount of pre emphasis to add to audio. Can be
            disabled by passing None.
            Defaults to 0.97
        features (int): Number of mel spectrogram freq bins to output.
            Defaults to 64
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        log (bool): Log features.
            Defaults to True
        log_zero_guard_type(str): Need to avoid taking the log of zero. There
            are two options: "add" or "clamp".
            Defaults to "add".
        log_zero_guard_value(float, or str): Add or clamp requires the number
            to add with or clamp to. log_zero_guard_value can either be a float
            or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
            passed.
            Defaults to 2**-24.
        dither (float): Amount of white-noise dithering.
            Defaults to 1e-5
        pad_to (int): Ensures that the output size of the time dimension is
            a multiple of pad_to.
            Defaults to 16
        frame_splicing (int): Defaults to 1
        exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
            // hop_length. Defaults to False.
        pad_value (float): The value that shorter mels are padded with.
            Defaults to 0
        mag_power (float): The power that the linear spectrogram is raised to
            prior to multiplication with mel basis.
            Defaults to 2 for a power spec
        rng : Random number generator
        nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
            samples in the batch.
            Defaults to 0.0
        nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
            Defaults to 4000
        use_torchaudio: Whether to use the `torchaudio` implementation.
        mel_norm: Normalization used for mel filterbank weights.
            Defaults to 'slaney' (area normalization)
        stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
        stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
    """

    def save_to(self, save_path: str):
        # All state lives in the featurizer submodule; nothing extra to serialize.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.

        processed_signal:
            0: AxisType(BatchTag)
            1: AxisType(MelSpectrogramSignalTag)
            2: AxisType(ProcessedTimeTag)
        processed_length:
            0: AxisType(BatchTag)
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        features=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2 ** -24,
        dither=1e-5,
        pad_to=16,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        use_torchaudio: bool = False,
        mel_norm="slaney",
        stft_exact_pad=False,
        stft_conv=False,
    ):
        # NOTE: super().__init__ must run before the f"{self}" error messages below,
        # since nn.Module's repr requires the module to be initialized.
        super().__init__(n_window_size, n_window_stride)

        self._sample_rate = sample_rate
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        # Fix: super().__init__() above ran before the seconds->samples conversion,
        # so win_length/hop_length were left as None whenever the float
        # window_size/window_stride arguments were used (the default path).
        # Re-assign them now that the sample counts are known.
        self.win_length = n_window_size
        self.hop_length = n_window_stride

        # Select the mel-spectrogram implementation (NeMo-native or torchaudio-backed).
        if not use_torchaudio:
            featurizer_class = FilterbankFeatures
        else:
            featurizer_class = FilterbankFeaturesTA
        self.featurizer = featurizer_class(
            sample_rate=self._sample_rate,
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            window=window,
            normalize=normalize,
            n_fft=n_fft,
            preemph=preemph,
            nfilt=features,
            lowfreq=lowfreq,
            highfreq=highfreq,
            log=log,
            log_zero_guard_type=log_zero_guard_type,
            log_zero_guard_value=log_zero_guard_value,
            dither=dither,
            pad_to=pad_to,
            frame_splicing=frame_splicing,
            exact_pad=exact_pad,
            pad_value=pad_value,
            mag_power=mag_power,
            rng=rng,
            nb_augmentation_prob=nb_augmentation_prob,
            nb_max_freq=nb_max_freq,
            mel_norm=mel_norm,
            stft_exact_pad=stft_exact_pad,
            stft_conv=stft_conv,
        )

    def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
        """Generate a random (signals, lengths) batch for export/tracing."""
        batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
        max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
        signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
        # Fix: draw lengths in [min_length, max_length] so no entry can exceed
        # the actual signal length (previously the upper bound was max_dim).
        lengths = torch.randint(low=min_length, high=max_length + 1, size=[batch_size])
        lengths[0] = max_length
        return signals, lengths

    def get_features(self, input_signal, length):
        # Delegate to the configured filterbank featurizer.
        return self.featurizer(input_signal, length)

    @property
    def filter_banks(self):
        # Expose the underlying mel filterbank matrix for inspection.
        return self.featurizer.filter_banks
|
|
|
|
class AudioToMFCCPreprocessor(AudioPreprocessor):
    """Preprocessor that converts wavs to MFCCs.
    Uses torchaudio.transforms.MFCC.

    Args:
        sample_rate: The sample rate of the audio.
            Defaults to 16000.
        window_size: Size of window for fft in seconds. Used to calculate the
            win_length arg for mel spectrogram.
            Defaults to 0.02
        window_stride: Stride of window for fft in seconds. Used to calculate
            the hop_length arg for mel spect.
            Defaults to 0.01
        n_window_size: Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride: Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window: Windowing function for fft. can be one of ['hann',
            'hamming', 'blackman', 'bartlett', 'ones', None].
            Defaults to 'hann'
        n_fft: Length of FT window. If None, it uses the smallest power of 2
            that is larger than n_window_size.
            Defaults to None
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        n_mels: Number of mel filterbanks.
            Defaults to 64
        n_mfcc: Number of coefficients to retain
            Defaults to 64
        dct_type: Type of discrete cosine transform to use
        norm: Type of norm to use
        log: Whether to use log-mel spectrograms instead of db-scaled.
            Defaults to True.
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), MFCCSpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def save_to(self, save_path: str):
        # All state lives in the torchaudio featurizer; nothing extra to serialize.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window='hann',
        n_fft=None,
        lowfreq=0.0,
        highfreq=None,
        n_mels=64,
        n_mfcc=64,
        dct_type=2,
        norm='ortho',
        log=True,
    ):
        self._sample_rate = sample_rate
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            raise ModuleNotFoundError(
                "torchaudio is not installed but is necessary for "
                "AudioToMFCCPreprocessor. We recommend you try "
                "building it from source for the PyTorch version you have."
            )
        # Exactly one of the seconds-based and samples-based window arguments may be set.
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )
        # Convert seconds to samples when the float arguments were used.
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        super().__init__(n_window_size, n_window_stride)

        # Arguments forwarded to the underlying MelSpectrogram transform.
        mel_kwargs = {}

        mel_kwargs['f_min'] = lowfreq
        mel_kwargs['f_max'] = highfreq
        mel_kwargs['n_mels'] = n_mels

        # Default n_fft: smallest power of 2 that covers the window.
        mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))

        mel_kwargs['win_length'] = n_window_size
        mel_kwargs['hop_length'] = n_window_stride

        # Resolve the window function from the shared lookup; unknown names
        # (anything outside the keys of self.torch_windows) are rejected.
        window_fn = self.torch_windows.get(window, None)
        if window_fn is None:
            raise ValueError(
                f"Window argument for AudioProcessor is invalid: {window}."
                f"For no window function, use 'ones' or None."
            )
        mel_kwargs['window_fn'] = window_fn

        # Use torchaudio's MFCC implementation
        self.featurizer = torchaudio.transforms.MFCC(
            sample_rate=self._sample_rate,
            n_mfcc=n_mfcc,
            dct_type=dct_type,
            norm=norm,
            log_mels=log,
            melkwargs=mel_kwargs,
        )

    def get_features(self, input_signal, length):
        features = self.featurizer(input_signal)
        # ceil(num_samples / hop_length) frames — presumably matches torchaudio's
        # centered STFT framing; TODO confirm against the torchaudio version in use.
        seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
        return features, seq_len
|
|
|
|
class SpectrogramAugmentation(NeuralModule):
    """
    Performs time and freq cuts in one of two ways.
    SpecAugment zeroes out vertical and horizontal sections as described in
    SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
    SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
    SpecCutout zeroes out rectangulars as described in Cutout
    (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
    `rect_masks`, `rect_freq`, and `rect_time`.

    Args:
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        time_masks (int): how many time segments should be cut
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in one
            segment.
            Defaults to 10.
        time_width (int): maximum number of time steps to be cut in one
            segment
            Defaults to 10.
        rect_masks (int): how many rectangular masks should be cut
            Defaults to 0.
        rect_time (int): maximum size of cut rectangles along the time
            dimension
            Defaults to 5.
        rect_freq (int): maximum size of cut rectangles along the frequency
            dimension
            Defaults to 20.
        rng: Random number generator passed through to the augmentations.
        mask_value (float): value that the masked regions are filled with.
            Defaults to 0.0.
        use_numba_spec_augment (bool): whether to use the Numba CUDA kernel
            for SpecAugment when the runtime supports it.
            Defaults to True.
    """

    @property
    def input_types(self):
        """Returns definitions of module input types
        """
        return {
            "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output types
        """
        return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}

    def __init__(
        self,
        freq_masks=0,
        time_masks=0,
        freq_width=10,
        time_width=10,
        rect_masks=0,
        rect_time=5,
        rect_freq=20,
        rng=None,
        mask_value=0.0,
        use_numba_spec_augment: bool = True,
    ):
        super().__init__()

        # SpecCutout is enabled only when rectangular masks are requested;
        # otherwise use an identity passthrough.
        if rect_masks > 0:
            self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,)

        else:
            self.spec_cutout = lambda input_spec: input_spec
        # Same pattern for SpecAugment: identity unless any masks are requested.
        if freq_masks + time_masks > 0:
            self.spec_augment = SpecAugment(
                freq_masks=freq_masks,
                time_masks=time_masks,
                freq_width=freq_width,
                time_width=time_width,
                rng=rng,
                mask_value=mask_value,
            )
        else:
            self.spec_augment = lambda input_spec, length: input_spec

        # Optional Numba CUDA kernel, kept alongside the torch implementation;
        # which one runs is decided per-batch in forward().
        if use_numba_spec_augment and numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__):
            logging.info('Numba CUDA SpecAugment kernel is being used')
            self.spec_augment_numba = SpecAugmentNumba(
                freq_masks=freq_masks,
                time_masks=time_masks,
                freq_width=freq_width,
                time_width=time_width,
                rng=rng,
                mask_value=mask_value,
            )
        else:
            self.spec_augment_numba = None

    @typecheck()
    def forward(self, input_spec, length):
        # Cutout first, then SpecAugment on the cutout result.
        augmented_spec = self.spec_cutout(input_spec=input_spec)

        # Prefer the Numba kernel only when the launch heuristics say the
        # batch is large enough to benefit; otherwise fall back to torch.
        if self.spec_augment_numba is not None and spec_augment_launch_heuristics(augmented_spec, length):
            augmented_spec = self.spec_augment_numba(input_spec=augmented_spec, length=length)
        else:
            augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
        return augmented_spec
|
|
|
|
class MaskedPatchAugmentation(NeuralModule):
    """
    Zeroes out fixed size time patches of the spectrogram.
    All samples in batch are guaranteed to have the same amount of masked time steps.
    Optionally also performs frequency masking in the same way as SpecAugment.

    Args:
        patch_size (int): up to how many time steps does one patch consist of.
            Defaults to 48.
        mask_patches (float): how many patches should be masked in each sample.
            if >= 1., interpreted as number of patches (after converting to int)
            if <1., interpreted as fraction of total tokens to be masked (number of patches is rounded up)
            Defaults to 10.
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in a segment.
            Defaults to 0.
    """

    @property
    def input_types(self):
        """Returns definitions of module input types
        """
        return {
            "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output types
        """
        return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}

    def __init__(
        self, patch_size: int = 48, mask_patches: float = 10.0, freq_masks: int = 0, freq_width: int = 0,
    ):
        super().__init__()
        self.patch_size = patch_size
        if mask_patches >= 1:
            # Interpreted as an absolute number of patches.
            self.mask_patches = int(mask_patches)
        elif mask_patches >= 0:
            # Interpreted as a fraction of the shortest sequence; resolved in forward().
            self._mask_fraction = mask_patches
            self.mask_patches = None
        else:
            raise ValueError('mask_patches cannot be negative')

        # Optional SpecAugment-style frequency masking applied after patch masking.
        if freq_masks > 0:
            self.spec_augment = SpecAugment(freq_masks=freq_masks, time_masks=0, freq_width=freq_width, time_width=0,)
        else:
            self.spec_augment = None

    @typecheck()
    def forward(self, input_spec, length):
        # NOTE: patch masking writes into input_spec in place.
        augmented_spec = input_spec

        min_len = torch.min(length)

        if self.mask_patches is None:
            # Derive the patch count from the masked-token fraction, rounding up.
            len_fraction = int(min_len * self._mask_fraction)
            mask_patches = len_fraction // self.patch_size + int(len_fraction % self.patch_size != 0)
        else:
            mask_patches = self.mask_patches

        if min_len < self.patch_size * mask_patches:
            # Fix: cast to a plain int — floor-division of a 0-dim tensor would
            # otherwise propagate a tensor into random.sample below.
            mask_patches = int(min_len // self.patch_size)

        for idx in range(input_spec.shape[0]):
            cur_len = int(length[idx])
            patches = range(cur_len // self.patch_size - 1)
            # Fix: never request more patches than the candidate population holds —
            # random.sample raises ValueError otherwise. This could happen for the
            # shortest sample, whose population excludes the final full patch.
            masked_patches = random.sample(patches, min(mask_patches, len(patches)))

            for mp in masked_patches:
                augmented_spec[idx, :, mp * self.patch_size : (mp + 1) * self.patch_size] = 0.0

        if self.spec_augment is not None:
            augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)

        return augmented_spec
|
|
|
|
class CropOrPadSpectrogramAugmentation(NeuralModule):
    """
    Pad or Crop the incoming Spectrogram to a certain shape.

    Args:
        audio_length (int): the final number of timesteps that is required.
            The signal will be either padded or cropped temporally to this
            size.
    """

    def __init__(self, audio_length):
        super().__init__()
        self.audio_length = audio_length

    @typecheck()
    @torch.no_grad()
    def forward(self, input_signal, length):
        spec = input_signal
        batch_size = spec.shape[0]
        target_len = self.audio_length
        current_len = spec.shape[-1]

        if current_len > target_len:
            # Too long: crop each batch element at an independent random offset.
            starts = torch.randint(low=0, high=current_len - target_len + 1, size=[batch_size])
            cropped = [spec[i : i + 1, :, s : s + target_len] for i, s in enumerate(starts)]
            spec = torch.cat(cropped, dim=0)
            del cropped
        else:
            # Too short (or equal): zero-pad symmetrically, extra column on the right.
            deficit = target_len - current_len
            pad_left = deficit // 2
            pad_right = deficit // 2
            if deficit % 2 == 1:
                pad_right += 1
            spec = torch.nn.functional.pad(spec, [pad_left, pad_right], mode="constant", value=0)

        # Every output now has exactly `target_len` valid timesteps.
        length = (length * 0) + target_len

        return spec, length

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass
|
|
|
|
class AudioToSpectrogram(NeuralModule):
    """Transform a batch of input multi-channel signals into a batch of
    STFT-based spectrograms.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window
        power: exponent for magnitude spectrogram. Default `None` will
            return a complex-valued spectrogram
    """

    def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = None):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            # Fix: the message was missing its f-prefix, so the class name was
            # printed literally as "{self.__class__.__name__}".
            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()

        # Only even FFT lengths are supported, giving F = fft_length // 2 + 1 subbands.
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=fft_length, hop_length=hop_length, power=power, pad_mode='constant'
        )

        # number of subbands
        self.F = fft_length // 2 + 1

    @property
    def num_subbands(self) -> int:
        """Number of frequency subbands, F = fft_length // 2 + 1."""
        return self.F

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of C-channel input signals
        into a batch of complex-valued spectrograms.

        Args:
            input: Time-domain input signal with C channels, shape (B, C, T)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Output spectrogram with F subbands and N time frames, shape (B, C, F, N)
            and output length with shape (B,).
        """
        B, T = input.size(0), input.size(-1)
        input = input.view(B, -1, T)

        # STFT is computed in float32; disable autocast to keep full precision.
        with torch.cuda.amp.autocast(enabled=False):
            output = self.stft(input.float())

        if input_length is not None:
            # Zero out frames that correspond to padding in the input signal.
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # No lengths provided: assume all frames are valid.
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid frames for the output.

        Args:
            input_length: number of valid samples, shape (B,)

        Returns:
            Number of valid frames, shape (B,)
        """
        # floor(T / hop_length) + 1 frames, matching a centered STFT.
        output_length = input_length.div(self.stft.hop_length, rounding_mode='floor').add(1).long()
        return output_length
|
|
|
|
class SpectrogramToAudio(NeuralModule):
    """Transform a batch of input multi-channel spectrograms into a batch of
    time-domain multi-channel signals.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window
    """

    def __init__(self, fft_length: int, hop_length: int):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            # Fix: the message was missing its f-prefix, so the class name was
            # printed literally as "{self.__class__.__name__}".
            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()

        # Only even FFT lengths are supported, giving F = fft_length // 2 + 1 subbands.
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=fft_length, hop_length=hop_length, pad_mode='constant'
        )

        # Number of expected frequency subbands in the input spectrogram.
        self.F = fft_length // 2 + 1

    @property
    def num_subbands(self) -> int:
        """Number of frequency subbands, F = fft_length // 2 + 1."""
        return self.F

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input complex-valued spectrogram to a time-domain
        signal. Multi-channel IO is supported.

        Note: the return annotation previously said a bare tensor; a
        (signal, length) tuple is what is actually returned.

        Args:
            input: Input spectrogram for C channels, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Time-domain signal with T time-domain samples and C channels, (B, C, T)
            and output length with shape (B,).
        """
        B, F, N = input.size(0), input.size(-2), input.size(-1)
        assert F == self.F, f'Number of subbands F={F} not matching self.F={self.F}'
        input = input.view(B, -1, F, N)

        # iSTFT is computed in complex64; disable autocast to keep full precision.
        with torch.cuda.amp.autocast(enabled=False):
            output = self.istft(input.cfloat())

        if input_length is not None:
            # Zero out samples that correspond to padding in the input spectrogram.
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # No lengths provided: assume all samples are valid.
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid samples for the output.

        Args:
            input_length: number of valid frames, shape (B,)

        Returns:
            Number of valid samples, shape (B,)
        """
        # (num_frames - 1) * hop_length samples, the inverse of the STFT framing.
        output_length = input_length.sub(1).mul(self.istft.hop_length).long()
        return output_length
|
|
|
|
@dataclass
class AudioToMelSpectrogramPreprocessorConfig:
    """Hydra/OmegaConf config mirroring AudioToMelSpectrogramPreprocessor's constructor arguments."""

    _target_: str = "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"
    sample_rate: int = 16000
    window_size: float = 0.02
    window_stride: float = 0.01
    n_window_size: Optional[int] = None
    n_window_stride: Optional[int] = None
    window: str = "hann"
    normalize: str = "per_feature"
    n_fft: Optional[int] = None
    preemph: float = 0.97
    features: int = 64
    lowfreq: int = 0
    highfreq: Optional[int] = None
    log: bool = True
    log_zero_guard_type: str = "add"
    log_zero_guard_value: float = 2 ** -24
    dither: float = 1e-5
    pad_to: int = 16
    frame_splicing: int = 1
    exact_pad: bool = False
    pad_value: int = 0
    mag_power: float = 2.0
    rng: Optional[str] = None
    nb_augmentation_prob: float = 0.0
    nb_max_freq: int = 4000
    use_torchaudio: bool = False
    mel_norm: str = "slaney"
    # Deprecated flags kept for compatibility with older checkpoints.
    stft_exact_pad: bool = False
    stft_conv: bool = False
|
|
|
|
@dataclass
class AudioToMFCCPreprocessorConfig:
    """Hydra/OmegaConf config mirroring AudioToMFCCPreprocessor's constructor arguments."""

    _target_: str = 'nemo.collections.asr.modules.AudioToMFCCPreprocessor'
    sample_rate: int = 16000
    window_size: float = 0.02
    window_stride: float = 0.01
    n_window_size: Optional[int] = None
    n_window_stride: Optional[int] = None
    window: str = 'hann'
    n_fft: Optional[int] = None
    lowfreq: Optional[float] = 0.0
    highfreq: Optional[float] = None
    n_mels: int = 64
    n_mfcc: int = 64
    dct_type: int = 2
    norm: str = 'ortho'
    log: bool = True
|
|
|
|
@dataclass
class SpectrogramAugmentationConfig:
    """Hydra/OmegaConf config mirroring SpectrogramAugmentation's constructor arguments."""

    # NOTE(review): several defaults here (freq_width=0, time_width=0, rect_time=0,
    # rect_freq=0) differ from SpectrogramAugmentation's own signature defaults
    # (10, 10, 5, 20) — confirm the divergence is intentional.
    _target_: str = "nemo.collections.asr.modules.SpectrogramAugmentation"
    freq_masks: int = 0
    time_masks: int = 0
    freq_width: int = 0
    time_width: Optional[Any] = 0
    rect_masks: int = 0
    rect_time: int = 0
    rect_freq: int = 0
    mask_value: float = 0
    rng: Optional[Any] = None
    use_numba_spec_augment: bool = True
|
|
|
|
@dataclass
class CropOrPadSpectrogramAugmentationConfig:
    """Hydra/OmegaConf config mirroring CropOrPadSpectrogramAugmentation; audio_length is required."""

    audio_length: int
    _target_: str = "nemo.collections.asr.modules.CropOrPadSpectrogramAugmentation"
|
|
|
|
@dataclass
class MaskedPatchAugmentationConfig:
    """Hydra/OmegaConf config mirroring MaskedPatchAugmentation's constructor arguments."""

    patch_size: int = 48
    mask_patches: float = 10.0
    freq_masks: int = 0
    freq_width: int = 0
    _target_: str = "nemo.collections.asr.modules.MaskedPatchAugmentation"
|
|