Spaces:
Running
Running
import logging

import torch
from torch import Tensor, nn

# Module-scoped logger, named after this module. The code visible in this
# chunk does not emit anything through it.
logger = logging.getLogger(name=__name__)
class Normalizer(nn.Module):
    """Normalize tensors with a single scalar running mean and variance.

    The statistics are exponential moving averages (EMA) over *all* elements
    of each input seen in training mode. Until the first update the module
    behaves as the identity (mean 0, std 1).

    Args:
        momentum: Weight given to the newest batch statistic in the EMA.
        eps: Added to the variance before the square root in ``running_std``.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # NaN is the "no statistics yet" sentinel. Registering the tensors as
        # buffers keeps them in state_dict and moves them with the module.
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self) -> bool:
        """Whether at least one (non-empty) update has been recorded."""
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self) -> Tensor:
        """Running mean, or 0 before the first update."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self) -> Tensor:
        """``sqrt(running_var + eps)``, or 1 before the first update."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    def _ema(self, a: Tensor, x: Tensor):
        """Blend the old value ``a`` toward ``x`` by ``self.momentum``."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x):
        """Fold the statistics of ``x`` into the running mean/variance.

        Runs without autograd so the stored buffers never hold a reference
        to the caller's computation graph. Empty inputs are ignored.
        """
        if x.numel() == 0:
            # mean/var of an empty tensor are NaN and would corrupt the EMA.
            return
        batch_mean = x.mean()
        if not self.started:
            self.running_mean_unsafe = batch_mean
            # Tensor.var() is unbiased, which yields NaN for one element; a NaN
            # here would propagate through every later EMA step, so use 0.
            if x.numel() > 1:
                self.running_var_unsafe = x.var()
            else:
                self.running_var_unsafe = torch.zeros_like(batch_mean)
        else:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, batch_mean)
            # Mean squared deviation around the *updated* running mean.
            batch_var = (x - self.running_mean_unsafe).pow(2).mean()
            self.running_var_unsafe = self._ema(self.running_var_unsafe, batch_var)

    def forward(self, x: Tensor, update=True):
        """Return ``(x - running_mean) / running_std``.

        In training mode with ``update=True`` the statistics are updated from
        ``x`` first. The statistics used are recorded in ``self.stats`` as
        Python floats.
        """
        if self.training and update:
            self.update_(x)
        mean, std = self.running_mean, self.running_std
        self.stats = dict(mean=mean.item(), std=std.item())
        return (x - mean) / std

    def inverse(self, x: Tensor):
        """Undo ``forward``: map normalized values back to the input scale."""
        return x * self.running_std + self.running_mean