import torch
import tqdm


class TorchScaler(torch.nn.Module):
    """
    This torch module implements scaling for input tensors, either instance-based
    or based on dataset-wide statistics.

    Args:
        statistic: str, (default="dataset"), how the statistics used for
            normalisation are computed. Choice in {"dataset", "instance"}.
            "dataset" requires calling fit() with a DataLoader over the dataset
            before use; "instance" applies the normalisation at the instance
            level, computing the statistics on the input itself (a clip or a batch).
        normtype: str, (default="standard"), the type of normalisation to use.
            Choice in {"standard", "mean", "minmax"}. "standard" applies classic
            normalisation with mean and standard deviation; "mean" subtracts the
            mean from the data; "minmax" subtracts the minimum of the data and
            divides by the difference between max and min.
        dims: tuple, (default=(1, 2)), the dimensions over which the statistics
            are computed.
        eps: float, (default=1e-8), small constant added to denominators to
            avoid division by zero.
    """

    def __init__(self, statistic="dataset", normtype="standard", dims=(1, 2), eps=1e-8):
        super(TorchScaler, self).__init__()
        assert statistic in ["dataset", "instance", None]
        assert normtype in ["standard", "mean", "minmax", None]
        if statistic == "dataset" and normtype == "minmax":
            raise NotImplementedError(
                "statistic==dataset and normtype==minmax is not currently implemented."
            )
        self.statistic = statistic
        self.normtype = normtype
        self.dims = dims
        self.eps = eps

    def load_state_dict(self, state_dict, strict=True):
        # The scaler only carries persistent state (the "mean" and "mean_squared"
        # buffers) when dataset-wide statistics are used; otherwise loading is a no-op.
        if self.statistic == "dataset":
            super(TorchScaler, self).load_state_dict(state_dict, strict)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if self.statistic == "dataset":
            super(TorchScaler, self)._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )

    def fit(self, dataloader, transform_func=lambda x: x[0]):
        """
        Fit the scaler by accumulating per-batch means and means of squares
        over the whole dataloader.

        Args:
            dataloader (DataLoader): training data DataLoader.
            transform_func (callable, optional): transform applied to each batch
                to extract the features. Defaults to lambda x: x[0].
        """
        indx = 0
        for batch in tqdm.tqdm(dataloader):
            feats = transform_func(batch)
            # Reduce over self.dims, then average over the batch dimension,
            # keeping a leading singleton dimension so the buffers broadcast
            # against future inputs.
            if indx == 0:
                mean = torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0)
                mean_squared = (
                    torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0)
                )
            else:
                mean += torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0)
                mean_squared += (
                    torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0)
                )
            indx += 1

        mean /= indx
        mean_squared /= indx

        self.register_buffer("mean", mean)
        self.register_buffer("mean_squared", mean_squared)

    def forward(self, tensor):
        if self.statistic is None or self.normtype is None:
            return tensor

        if self.statistic == "dataset":
            assert hasattr(self, "mean") and hasattr(
                self, "mean_squared"
            ), "TorchScaler should be fit before used if statistic=dataset"
            assert tensor.ndim == self.mean.ndim, (
                "Pre-computed statistics have a different number of dimensions "
                "than the input tensor."
            )
            if self.normtype == "mean":
                return tensor - self.mean
            elif self.normtype == "standard":
                # Standard deviation recovered from E[x^2] - E[x]^2.
                std = torch.sqrt(self.mean_squared - self.mean ** 2)
                return (tensor - self.mean) / (std + self.eps)
            else:
                raise NotImplementedError
        else:
            if self.normtype == "mean":
                return tensor - torch.mean(tensor, self.dims, keepdim=True)
            elif self.normtype == "standard":
                return (tensor - torch.mean(tensor, self.dims, keepdim=True)) / (
                    torch.std(tensor, self.dims, keepdim=True) + self.eps
                )
            elif self.normtype == "minmax":
                return (tensor - torch.amin(tensor, dim=self.dims, keepdim=True)) / (
                    torch.amax(tensor, dim=self.dims, keepdim=True)
                    - torch.amin(tensor, dim=self.dims, keepdim=True)
                    + self.eps
                )
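

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper: the feature shape,
    # batch size, and toy DataLoader below are illustrative assumptions only;
    # TorchScaler itself does not depend on them.
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical log-mel-like features: (n_clips, n_mels, n_frames).
    feats = torch.randn(32, 64, 100)
    labels = torch.zeros(32)
    loader = DataLoader(TensorDataset(feats, labels), batch_size=8)

    # Dataset-wide standardisation: fit() accumulates the mean and mean of
    # squares over the loader, then forward() standardises new tensors with them.
    scaler = TorchScaler(statistic="dataset", normtype="standard", dims=(1, 2))
    scaler.fit(loader, transform_func=lambda batch: batch[0])
    standardised = scaler(feats)

    # Instance-level min-max scaling needs no fitting: statistics are computed
    # per input over `dims`, mapping each clip approximately into [0, 1].
    inst_scaler = TorchScaler(statistic="instance", normtype="minmax", dims=(1, 2))
    scaled = inst_scaler(feats)
    print(standardised.shape, scaled.amin().item(), scaled.amax().item())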