| import math |
| import numbers |
| from typing import Optional |
|
|
| import numpy as np |
| from fairseq.data.audio.feature_transforms import ( |
| AudioFeatureTransform, |
| register_audio_feature_transform, |
| ) |
|
|
|
|
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)

    Augments a 2-D spectrogram whose axis 0 is indexed as frames (time) and
    whose axis 1 is indexed as frequency bins. Up to three distortions are
    applied, in this order:

    1. time warping (only if ``time_warp_w > 0``; requires ``cv2``),
    2. ``freq_mask_n`` frequency masks, each narrower than ``freq_mask_f``,
    3. ``time_mask_n`` time masks, each shorter than
       ``min(time_mask_t, floor(num_frames * time_mask_p), num_frames)``.

    Masked cells are overwritten with ``mask_value``, or with the mean of the
    input spectrogram when ``mask_value`` is ``None``.
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Build a transform from a config mapping.

        Recognized keys: ``time_warp_W``, ``freq_mask_N``, ``freq_mask_F``,
        ``time_mask_N``, ``time_mask_T``, ``time_mask_p`` and ``mask_value``.
        Missing numeric keys default to 0 (that distortion is disabled).

        Note that a missing ``mask_value`` key yields ``None`` (fill with the
        spectrogram mean), which differs from the ``__init__`` default of
        ``0.0``.

        Args:
            config: mapping with the keys above, or ``None`` for all defaults.

        Returns:
            A new instance of ``cls``.
        """
        _config = {} if config is None else config
        # Instantiate via ``cls`` so that subclasses inheriting this factory
        # get an instance of their own type rather than of the base class.
        return cls(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        """Store and validate the augmentation parameters.

        Args:
            time_warp_w: time-warp parameter W; warping is skipped when 0 or
                when the input has at most ``2 * W`` frames.
            freq_mask_n: number of frequency masks to apply.
            freq_mask_f: exclusive upper bound on each frequency mask width.
            time_mask_n: number of time masks to apply.
            time_mask_t: exclusive upper bound on each time mask length.
            time_mask_p: additionally bounds each time mask length by
                ``floor(num_frames * time_mask_p)``.
            mask_value: fill value for masked cells; ``None`` means "use the
                mean of the input spectrogram".

        Raises:
            AssertionError: if ``mask_value`` is neither ``None`` nor a
                number, if ``freq_mask_n > 0`` with ``freq_mask_f <= 0``, or
                if ``time_mask_n > 0`` with ``time_mask_t <= 0``.
        """
        # Checks
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        """Return ``ClassName(param=value, ...)`` for the masking/warping parameters."""
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"time_warp_w={self.time_warp_w}",
                    f"freq_mask_n={self.freq_mask_n}",
                    f"freq_mask_f={self.freq_mask_f}",
                    f"time_mask_n={self.time_mask_n}",
                    f"time_mask_t={self.time_mask_t}",
                    f"time_mask_p={self.time_mask_p}",
                ]
            )
            + ")"
        )

    def __call__(self, spectrogram):
        """Apply the configured distortions to ``spectrogram``.

        Args:
            spectrogram: 2-D array; axis 0 is used as frames, axis 1 as
                frequency bins.

        Returns:
            A distorted copy of the input. The input object itself is
            returned, unmodified, when it has no frames or when it has fewer
            frequency bins than ``freq_mask_f`` (in that case time masking is
            skipped as well).

        Raises:
            AssertionError: if ``spectrogram`` is not 2-D.
        """
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        num_frames = spectrogram.shape[0]
        num_freqs = spectrogram.shape[1]

        # Run the early-exit guards before any work on the data: this avoids
        # a needless copy and, for a zero-frame input with mask_value=None,
        # avoids taking the mean of an empty array.
        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        distorted = spectrogram.copy()  # make a copy of input spectrogram.

        mask_value = self.mask_value
        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if self.time_warp_w > 0:
            # Warping needs a split point at least W frames from either end.
            if 2 * self.time_warp_w < num_frames:
                # Imported lazily so cv2 is only required when warping is on.
                import cv2

                # w0: frame at which the spectrogram is split;
                # w: how far that split point is shifted.
                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                # cv2 takes dsize as (width, height) = (columns, rows). The
                # two parts are resized to w0 + w and num_frames - w0 - w
                # rows, so the total frame count is unchanged.
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _ in range(self.freq_mask_n):
            # f < freq_mask_f <= num_freqs, so num_freqs - f >= 1.
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        # Also bound the mask length by num_frames: with time_mask_p > 1 the
        # proportional bound alone could exceed the number of frames, which
        # would make the randint range for t0 below empty.
        max_time_mask_t = min(
            self.time_mask_t,
            math.floor(num_frames * self.time_mask_p),
            num_frames,
        )
        if max_time_mask_t < 1:
            return distorted

        for _ in range(self.time_mask_n):
            # t < max_time_mask_t <= num_frames, so num_frames - t >= 1.
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted
|
|