lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
9.1 kB
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ dataset.py ]
# Synopsis [ the phone dataset ]
# Author [ S3PRL, Xuankai Chang ]
# Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ]
"""*********************************************************************************************"""
# Adaptive_SpecAugment Author: ShampooWang, cornliu
###############
# IMPORTATION #
###############
import os
import random
#-------------#
import pandas as pd
#-------------#
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import Dataset
#-------------#
import torchaudio
DEFAULT_TIME_WARP_MODE = "bicubic"
class SpecAug(torch.nn.Module):
"""SpecAug"""
def __init__(
self,
apply_time_warp=True,
time_warp_window=5,
time_warp_mode="bicubic",
apply_freq_mask=True,
freq_mask_width_range=(0,20),
num_freq_mask=2,
apply_time_mask=True,
time_mask_width_range=(0,100),
num_time_mask=2,
adaptive_number_ratio = 0.04,
adaptive_size_ratio = 0.04,
max_n_time_masks = 20,
adaptive=False
):
assert any([apply_time_warp, apply_freq_mask, apply_time_mask])
super(SpecAug, self).__init__()
self.apply_time_warp = apply_time_warp
self.apply_freq_mask = apply_freq_mask
self.apply_time_mask = apply_time_mask
if apply_time_warp:
self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
else:
self.time_warp = None
if apply_freq_mask:
self.freq_mask = MaskAlongAxis(
dim="freq",
mask_width_range=freq_mask_width_range,
num_mask=num_freq_mask,
)
else:
self.freq_mask = None
if apply_time_mask:
self.time_mask = MaskAlongAxis(
dim="time",
mask_width_range=time_mask_width_range,
num_mask=num_time_mask,
adaptive=adaptive,
adaptive_number_ratio=adaptive_number_ratio,
adaptive_size_ratio=adaptive_size_ratio,
max_n_time_masks=max_n_time_masks
)
else:
self.time_mask = None
def apply_specaug(self, x, x_lengths=None):
if self.time_warp is not None:
x, x_lengths = self.time_warp(x, x_lengths)
if self.freq_mask is not None:
x, x_lengths = self.freq_mask(x, x_lengths)
if self.time_mask is not None:
x, x_lengths = self.time_mask(x, x_lengths)
return x, x_lengths
def forward(self, xs, x_lengths=None):
"""Forward SpecAug.
Args:
xs: list of features [(T, D)] x batchsize
x_lengths: length of features
Return:
list of features[(T, D)] x batchsize
"""
assert len(xs[0].size()) == 2
if x_lengths is None:
x_lengths = torch.LongTensor([x.size(0) for x in xs])
batchsize, max_len, dim = len(x_lengths), torch.max(x_lengths).item(), xs[0].size(1)
xs_pad = xs[0].new_zeros((batchsize, max_len, dim))
for i, x in enumerate(xs):
xs_pad[i, :x_lengths[i]] = x
xs_pad, _ = self.apply_specaug(xs_pad, x_lengths)
xs = [xs_pad[i, :xs[i].size(0)] for i in range(batchsize)]
return xs, x_lengths
class TimeWarp(torch.nn.Module):
"""Time warping using torch.interpolate.
Args:
window: time warp parameter
mode: Interpolate mode
"""
def __init__(self, window=80, mode=DEFAULT_TIME_WARP_MODE):
super().__init__()
self.window = window
self.mode = mode
def extra_repr(self):
return f"window={self.window}, mode={self.mode}"
def time_warp(self, x):
org_size = x.size()
if x.dim() == 3:
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
x = x[:, None]
t = x.shape[2]
if t - self.window <= self.window:
return x.view(*org_size)
center = torch.randint(self.window, t - self.window, (1,))[0]
warped = torch.randint(center - self.window, center + self.window, (1,))[0] + 1
# left: (Batch, Channel, warped, Freq)
# right: (Batch, Channel, time - warped, Freq)
left = torch.nn.functional.interpolate(
x[:, :, :center], (warped, x.shape[3]), mode=self.mode, align_corners=False
)
right = torch.nn.functional.interpolate(
x[:, :, center:], (t - warped, x.shape[3]), mode=self.mode, align_corners=False
)
if x.requires_grad:
x = torch.cat([left, right], dim=-2)
else:
x[:, :, :warped] = left
x[:, :, warped:] = right
return x.view(*org_size)
def forward(self, x, x_lengths=None):
"""Forward function.
Args:
x: (Batch, Time, Freq)
x_lengths: (Batch,)
"""
ys = x.new_zeros(x.size())
for i in range(x.size(0)):
_y = self.time_warp(
x[i][None, : x_lengths[i]],
)[0]
ys[i, : x_lengths[i]] = _y
return ys, x_lengths
class MaskAlongAxis(torch.nn.Module):
def __init__(
self,
mask_width_range=(0, 30),
num_mask=2,
dim="time",
replace_with_zero=True,
adaptive_number_ratio = 0.04,
adaptive_size_ratio = 0.04,
max_n_time_masks = 20,
adaptive = False
):
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
raise TypeError(
f"mask_width_range must be a tuple of int and int values: "
f"{mask_width_range}",
)
assert mask_width_range[1] > mask_width_range[0]
if isinstance(dim, str):
if dim == "time":
dim = 1
elif dim == "freq":
dim = 2
else:
raise ValueError("dim must be int, 'time' or 'freq'")
if dim == 1:
self.mask_axis = "time"
elif dim == 2:
self.mask_axis = "freq"
else:
self.mask_axis = "unknown"
super().__init__()
self.mask_width_range = mask_width_range
self.num_mask = num_mask
self.dim = dim
self.replace_with_zero = replace_with_zero
###############################################
self.adaptive = adaptive
self.adaptive_number_ratio = adaptive_number_ratio
self.adaptive_size_ratio = adaptive_size_ratio
self.max_n_time_masks = max_n_time_masks
###############################################
def mask_along_axis(self, spec, spec_lengths):
org_size = spec.size()
if spec.dim() == 4:
# spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
spec = spec.view(-1, spec.size(2), spec.size(3))
B = spec.shape[0]
# D = Length or Freq
D = spec.shape[self.dim]
T = self.mask_width_range[1]
num_mask = self.num_mask
if self.dim == 1 & self.adaptive :
if self.adaptive_number_ratio > 0:
num_mask = min(int(self.adaptive_number_ratio * D), self.max_n_time_masks)
if self.adaptive_size_ratio > 0:
T = min(self.mask_width_range[1], int(self.adaptive_size_ratio * D))
# mask_length: (B, num_mask, 1)
mask_length = torch.randint(
self.mask_width_range[0],
T,
(B, num_mask),
device=spec.device,
).unsqueeze(2)
# mask_pos: (B, num_mask, 1)
mask_pos = torch.randint(
0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
).unsqueeze(2)
# aran: (1, 1, D)
aran = torch.arange(D, device=spec.device)[None, None, :]
# mask: (Batch, num_mask, D)
mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
# Multiply masks: (Batch, num_mask, D) -> (Batch, D)
mask = mask.any(dim=1)
if self.dim == 1:
# mask: (Batch, Length, 1)
mask = mask.unsqueeze(2)
elif self.dim == 2:
# mask: (Batch, 1, Freq)
mask = mask.unsqueeze(1)
if self.replace_with_zero:
value = 0.0
else:
value = spec.mean()
if spec.requires_grad:
spec = spec.masked_fill(mask, value)
else:
spec = spec.masked_fill_(mask, value)
spec = spec.view(*org_size)
return spec, spec_lengths
def forward(self, spec, spec_lengths=None):
"""Forward function.
Args:
spec: (Batch, Length, Freq)
"""
return self.mask_along_axis(
spec,
spec_lengths,
)