# Provenance (text from the hosting page): OpenSound, "Upload 544 files", commit 3b6a091 (verified).
import tqdm
import torch
class TorchScaler(torch.nn.Module):
    """
    This torch module implements scaling for input tensors, both instance based
    and dataset-wide statistic based.

    Args:
        statistic: str or None, (default='dataset'), how the normalisation statistics are obtained.
            Choice in {'dataset', 'instance', None}.
            'dataset' uses statistics accumulated by ``fit()`` over a dataloader (or restored
            through ``load_state_dict``).
            'instance' computes the statistics on the input tensor itself (reduced over ``dims``),
            so it can be applied to a single clip or to a batch.
            None disables scaling: ``forward`` returns its input unchanged.
        normtype: str or None, (default='standard') the type of normalisation to use.
            Choice in {'standard', 'mean', 'minmax', None}. 'standard' subtracts the mean and
            divides by the standard deviation. 'mean' subtracts the mean from the data. 'minmax'
            subtracts the minimum of the data and divides by the difference between max and min.
            None disables scaling.
        dims: tuple of int, (default=(1, 2)), dimensions over which the statistics are reduced.
        eps: float, (default=1e-8), added to every denominator to avoid division by zero.

    Raises:
        AssertionError: if ``statistic`` or ``normtype`` is not one of the allowed values.
        NotImplementedError: if ``statistic == 'dataset'`` and ``normtype == 'minmax'``.
    """

    def __init__(self, statistic="dataset", normtype="standard", dims=(1, 2), eps=1e-8):
        super().__init__()
        assert statistic in ["dataset", "instance", None]
        assert normtype in ["standard", "mean", "minmax", None]
        if statistic == "dataset" and normtype == "minmax":
            raise NotImplementedError(
                "statistic==dataset and normtype==minmax is not currently implemented."
            )
        self.statistic = statistic
        self.normtype = normtype
        self.dims = dims
        self.eps = eps

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """
        Load the fitted statistics; only meaningful when ``statistic == 'dataset'``.

        Args:
            state_dict: mapping produced by ``state_dict()``.
            strict: forwarded to ``torch.nn.Module.load_state_dict``.
            **kwargs: forwarded as well (e.g. ``assign`` on recent PyTorch versions).

        Returns:
            The result of ``torch.nn.Module.load_state_dict`` for 'dataset' scalers,
            otherwise None (the state dict is ignored).
        """
        if self.statistic == "dataset":
            return super().load_state_dict(state_dict, strict, **kwargs)
        return None

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Called by PyTorch for this module both on a direct load_state_dict() and when the
        # scaler is a submodule of the model being loaded. Non-'dataset' scalers keep no state.
        if self.statistic != "dataset":
            return
        # The buffers only exist after fit(); on a fresh scaler create placeholders shaped like
        # the saved tensors so the base implementation can copy into them instead of reporting
        # them as unexpected keys.
        for name in ("mean", "mean_squared"):
            key = prefix + name
            saved = state_dict.get(key)
            if isinstance(saved, torch.Tensor) and not hasattr(self, name):
                self.register_buffer(name, torch.empty_like(saved))
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _batch_average(self, feats):
        """Average ``feats`` over ``self.dims`` then over dim 0, keeping a leading dim of size 1."""
        return torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0)

    def fit(self, dataloader, transform_func=lambda x: x[0]):
        """
        Scaler fitting: accumulate the dataset-wide statistics.

        For each batch, ``transform_func(batch)`` gives the features; their average (and the
        average of their squares) over ``self.dims`` and the batch dimension is accumulated,
        then averaged over the number of batches and stored as the buffers ``mean`` and
        ``mean_squared``.

        Args:
            dataloader (DataLoader): training data DataLoader (any iterable of batches).
            transform_func (lambda function, optional): Transforms applied to the data.
                Defaults to lambda x: x[0].

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        mean = None
        mean_squared = None
        n_batches = 0
        # Statistics are constants: do not record an autograd graph while accumulating them.
        with torch.no_grad():
            for batch in tqdm.tqdm(dataloader):
                feats = transform_func(batch)
                if mean is None:
                    mean = self._batch_average(feats)
                    mean_squared = self._batch_average(feats ** 2)
                else:
                    mean += self._batch_average(feats)
                    mean_squared += self._batch_average(feats ** 2)
                n_batches += 1
        if n_batches == 0:
            raise ValueError("TorchScaler.fit() received a dataloader that yielded no batches.")
        # NOTE: every batch gets equal weight, so a smaller final batch counts as much as a
        # full one.
        self.register_buffer("mean", mean / n_batches)
        self.register_buffer("mean_squared", mean_squared / n_batches)

    def forward(self, tensor):
        """
        Scale ``tensor`` according to ``self.statistic`` and ``self.normtype``.

        Returns:
            The scaled tensor, or ``tensor`` itself when statistic or normtype is None.

        Raises:
            AssertionError: for 'dataset' scaling when the scaler has not been fit, or when the
                input rank differs from the rank of the stored statistics.
            NotImplementedError: for an unsupported normtype.
        """
        if self.statistic is None or self.normtype is None:
            return tensor

        if self.statistic == "dataset":
            assert hasattr(self, "mean") and hasattr(
                self, "mean_squared"
            ), "TorchScaler should be fit before used if statistics=dataset"
            assert tensor.ndim == self.mean.ndim, (
                f"Pre-computed statistics have {self.mean.ndim} dims but the input tensor "
                f"has {tensor.ndim} dims."
            )
            if self.normtype == "mean":
                return tensor - self.mean
            if self.normtype == "standard":
                # E[x^2] - E[x]^2 can dip slightly below zero through float round-off;
                # clamp it so sqrt never yields NaN.
                variance = (self.mean_squared - self.mean ** 2).clamp(min=0.0)
                return (tensor - self.mean) / (torch.sqrt(variance) + self.eps)
            raise NotImplementedError

        if self.normtype == "mean":
            return tensor - torch.mean(tensor, self.dims, keepdim=True)
        if self.normtype == "standard":
            # NOTE: torch.std applies Bessel's correction by default, whereas the 'dataset'
            # branch uses the population standard deviation.
            return (tensor - torch.mean(tensor, self.dims, keepdim=True)) / (
                torch.std(tensor, self.dims, keepdim=True) + self.eps
            )
        if self.normtype == "minmax":
            minimum = torch.amin(tensor, dim=self.dims, keepdim=True)
            maximum = torch.amax(tensor, dim=self.dims, keepdim=True)
            return (tensor - minimum) / (maximum - minimum + self.eps)
        raise NotImplementedError