Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
class BaseScaler():
    """Base class mapping a logSNR value to a pair of coefficients ``(a, b)``.

    Subclasses implement :meth:`scalers`. Instances are callable; when
    limits have been stored via :meth:`setup_limits`, the raw coefficients
    are affinely rescaled with respect to those limits before being returned.
    """

    def __init__(self):
        # None disables rescaling in __call__; otherwise a 4-item list
        # [min_a, max_a, min_b, max_b].
        self.stretched_limits = None

    def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
        """Compute, store and return the limits used by :meth:`stretch_limits`.

        Args:
            schedule: Callable invoked as ``schedule(tensor, shift=shift)``,
                once with ``torch.ones(1)`` and once with ``torch.zeros(1)``.
            input_scaler: Callable taking a schedule output and returning two
                values that each support ``.item()``.
            stretch_max: When true, ``min_a`` and ``max_b`` are taken from the
                scaler evaluated at ``schedule(ones)``; otherwise 0 and 1.
            stretch_min: When true, ``max_a`` and ``min_b`` are taken from the
                scaler evaluated at ``schedule(zeros)``; otherwise 1 and 0.
            shift: Forwarded unchanged to ``schedule``.

        Returns:
            The list ``[min_a, max_a, min_b, max_b]``, which is also stored on
            ``self.stretched_limits``.
        """
        # Both schedule endpoints are always evaluated, even when the
        # corresponding flag is off.
        # NOTE(review): if input_scaler is this same instance and limits are
        # already set, the values read here are already rescaled - confirm
        # whether repeated calls are expected.
        logsnr_at_one = schedule(torch.ones(1), shift=shift)
        logsnr_at_zero = schedule(torch.zeros(1), shift=shift)

        if stretch_max:
            min_a, max_b = [coef.item() for coef in input_scaler(logsnr_at_one)]
        else:
            min_a, max_b = [0, 1]

        if stretch_min:
            max_a, min_b = [coef.item() for coef in input_scaler(logsnr_at_zero)]
        else:
            max_a, min_b = [1, 0]

        limits = [min_a, max_a, min_b, max_b]
        self.stretched_limits = limits
        return limits

    def stretch_limits(self, a, b):
        """Rescale ``a`` and ``b`` so the stored min maps to 0 and max to 1."""
        lo_a, hi_a, lo_b, hi_b = self.stretched_limits
        stretched_a = (a - lo_a) / (hi_a - lo_a)
        stretched_b = (b - lo_b) / (hi_b - lo_b)
        return stretched_a, stretched_b

    def scalers(self, logSNR):
        """Return the raw ``(a, b)`` pair for ``logSNR``; subclasses override."""
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        """Return ``(a, b)`` for ``logSNR``, rescaled when limits are set."""
        a, b = self.scalers(logSNR)
        if self.stretched_limits is None:
            return a, b
        a, b = self.stretch_limits(a, b)
        return a, b
class VPScaler(BaseScaler):
    """Scaler whose raw coefficients satisfy ``a**2 + b**2 == 1``."""

    def scalers(self, logSNR):
        """Return ``(sqrt(sigmoid(logSNR)), sqrt(1 - sigmoid(logSNR)))``.

        ``logSNR`` must provide tensor-style ``.sigmoid()``; the results are
        computed with the tensor methods ``.sigmoid()`` and ``.sqrt()``.
        """
        alpha_sq = logSNR.sigmoid()
        beta_sq = 1 - alpha_sq
        return alpha_sq.sqrt(), beta_sq.sqrt()
class LERPScaler(BaseScaler):
    """Scaler whose raw coefficients are interpolation weights: ``a + b == 1``."""

    def scalers(self, logSNR):
        """Return ``(a, b)`` with ``a = sigmoid(logSNR / 2)`` and ``b = 1 - a``.

        This is the algebraic simplification of
        ``1 + (2 - sqrt(4 + 4*x)) / (2*x)`` with ``x = exp(logSNR) - 1``:
        since ``sqrt(4 + 4*x) == 2*exp(logSNR / 2)``, the expression reduces
        to ``1 - 1 / (1 + exp(logSNR / 2))``.

        Args:
            logSNR: Tensor of log signal-to-noise ratios.

        Returns:
            Tuple ``(a, b)`` of tensors shaped like ``logSNR``.
        """
        # The sigmoid form is used instead of the expanded closed form because
        # the latter divides by exp(logSNR) - 1: it is undefined at
        # logSNR == 0, suffers cancellation for tiny |logSNR| (yielding 1.0
        # instead of ~0.5 in float32), and overflows to NaN for large logSNR.
        a = (logSNR / 2).sigmoid()
        b = 1 - a
        return a, b