|
import torch |
|
from typeguard import check_argument_types |
|
from typing import Sequence |
|
from typing import Union |
|
|
|
|
|
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Apply SpecAugment-style masking along the time or frequency axis.

    For every utterance in the batch, ``num_mask`` contiguous bands are
    drawn independently along ``dim`` and filled with a constant value.

    Args:
        spec: Input features of shape (Batch, Length, Freq) or
            (Batch, Channel, Length, Freq); 4-D input is folded to 3-D
            and restored before returning.
        spec_lengths: (Batch,) lengths. Not used for masking in this
            implementation; returned unchanged for interface compatibility.
        mask_width_range: Each mask width is drawn uniformly from
            ``[mask_width_range[0], mask_width_range[1])``.
        dim: Axis to mask in the 3-D (B, L, F) view: 1 for time, 2 for freq.
        num_mask: Number of masks applied per utterance.
        replace_with_zero: Fill masked positions with 0.0 when True,
            otherwise with the mean of ``spec``.

    Returns:
        Tuple of (masked spec with the original shape, spec_lengths).

    Raises:
        ValueError: If ``dim`` is not 1 or 2.
    """
    if dim not in (1, 2):
        # Without this guard the mask below is never unsqueezed and the
        # masked_fill broadcasts incorrectly (or fails with an opaque error).
        raise ValueError(f"dim must be 1 (time) or 2 (freq): {dim}")

    org_size = spec.size()
    if spec.dim() == 4:
        # Fold (Batch, Channel, Length, Freq) -> (Batch*Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    # D: size of the axis being masked
    D = spec.shape[dim]

    # (B, num_mask, 1): random width of each mask band
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # (B, num_mask, 1): random start positions, clamped so every band fits
    # inside the axis even for the widest sampled mask.
    mask_pos = torch.randint(
        0, max(1, D - int(mask_length.max())), (B, num_mask), device=spec.device
    ).unsqueeze(2)

    # (1, 1, D) position grid, broadcast against (B, num_mask, 1) starts/widths
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # True inside a band; then merge the num_mask bands per utterance.
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    mask = mask.any(dim=1)
    if dim == 1:
        # (B, Length, 1): broadcast over the Freq axis
        mask = mask.unsqueeze(2)
    else:
        # dim == 2 -> (B, 1, Freq): broadcast over the Length axis
        mask = mask.unsqueeze(1)

    value = 0.0 if replace_with_zero else spec.mean()

    if spec.requires_grad:
        # Out-of-place fill keeps the autograd graph intact.
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths
|
|
|
|
|
class MaskAlongAxis(torch.nn.Module):
    """SpecAugment time/frequency masking module.

    Thin ``nn.Module`` wrapper around :func:`mask_along_axis` that stores
    the masking configuration at construction time.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        """Initialize the module.

        Args:
            mask_width_range: A single int ``w`` (treated as ``(0, w)``) or a
                ``(min, max)`` pair; mask widths are drawn from ``[min, max)``.
            num_mask: Number of masks per utterance.
            dim: 1 or "time" for time masking, 2 or "freq" for frequency.
            replace_with_zero: Fill masked positions with 0.0 when True,
                otherwise with the mean of the input.

        Raises:
            TypeError: If ``mask_width_range`` is not length-2.
            ValueError: If ``dim`` is a string other than "time"/"freq".
        """
        assert check_argument_types()
        # Initialize nn.Module FIRST: Module.__setattr__ relies on the
        # _parameters/_buffers/_modules dicts that Module.__init__ creates,
        # so attribute assignment must not happen before this call.
        super().__init__()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )

        assert mask_width_range[1] > mask_width_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        # Human-readable axis label, shown by extra_repr().
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        """Return the configuration string shown in the module repr."""
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Apply masking.

        Args:
            spec: (Batch, Length, Freq)
            spec_lengths: (Batch,) lengths; passed through unchanged.

        Returns:
            Tuple of (masked spec, spec_lengths).
        """
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )
|
|