# chenming-wu · commit 436b829 (verified)
from abc import ABC
from typing import Sequence, Union
import torch
from torch.distributions import LogisticNormal
class Timesteps(ABC):
    """
    Timesteps base class.

    Stores the maximum timestep ``T``. The Python type of ``T`` selects the
    schedule kind: an ``int`` means discrete, a ``float`` means continuous.
    """

    def __init__(self, T: Union[int, float]):
        """
        Args:
            T: Maximum timestep (inclusive). Must be strictly positive.

        Raises:
            ValueError: If ``T`` is not strictly positive (NaN included).
        """
        # Explicit check rather than ``assert`` so validation is not stripped
        # under ``python -O``; ``not T > 0`` (vs. ``T <= 0``) also rejects NaN.
        if not T > 0:
            raise ValueError(f"T must be > 0, got {T!r}")
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        int if discrete, float if continuous.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. ``T`` is a float.
        """
        return isinstance(self.T, float)
class LogitNormalTrainingTimesteps(Timesteps):
    """
    Logit-Normal sampling of timesteps in [0, T].

    For scalar ``loc``/``scale``, the first coordinate of a
    ``LogisticNormal(loc, scale)`` sample equals ``sigmoid(x)`` with
    ``x ~ Normal(loc, scale)``, so timesteps are ``T * sigmoid(x)``.
    """

    def __init__(self, T: Union[int, float], loc: float, scale: float):
        """
        Args:
            T: Maximum timestep (inclusive); int for discrete, float for
                continuous schedules.
            loc: Mean of the underlying Normal (in logit space).
            scale: Standard deviation of the underlying Normal.
        """
        super().__init__(T)
        self.dist = LogisticNormal(loc, scale)

    def sample(
        self,
        size: Sequence[int],
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """
        Draw timesteps of shape ``size``.

        Args:
            size: Shape of the returned tensor.
            device: Device the result is moved to after sampling; sampling
                itself happens wherever ``self.dist``'s parameters live.

        Returns:
            If continuous: a float tensor with values in [0, T].
            If discrete: an int32 tensor of values in {0, ..., T}, obtained
            by ``torch.round`` (round-half-to-even) of the scaled samples.
        """
        # A 1-d LogisticNormal yields points on the 2-simplex with shape
        # (*size, 2); coordinate 0 is sigmoid of the Normal draw, in [0, 1].
        t = self.dist.sample(size)[..., 0].to(device).mul_(self.T)
        return t if self.is_continuous() else t.round().int()