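# NOTE: the three lines below are web-page residue captured together with this
# file; they are not part of the program and are kept only as comments.
# Spaces:
# Runtime error
# Runtime error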
import functools

import torch
import torch.distributions

import utils

LOGGER = utils.log.getLogger(__name__)

# NOTE(review): not referenced in this part of the file; the name suggests a
# one-time guard (e.g. for registering KL divergences) — confirm with callers.
__defined_kl = False

# Clamp bound used by `clamp_probs`: probabilities are clamped into
# [EPS, 1 - EPS] before renormalisation.
EPS = 1e-5


def clamp_probs(probs, eps=None):
    """Clamp probabilities into ``[eps, 1 - eps]`` and renormalise the last dim.

    Clamping keeps every entry strictly away from 0 and 1 (presumably so that
    downstream log / division operations stay finite), but it breaks
    normalisation, so the result is divided by its sum over the last
    dimension to put it back on the simplex.

    Args:
        probs: Tensor of probabilities; the last dimension is treated as the
            axis that should sum to 1.
        eps: Clamp bound (upper bound is ``1 - eps``).  ``None`` (default)
            uses the module-level ``EPS``, looked up at call time.

    Returns:
        Tensor of the same shape as ``probs`` whose last dimension sums to 1.
        After renormalisation individual entries may drift slightly outside
        ``[eps, 1 - eps]``.
    """
    if eps is None:
        eps = EPS
    probs = probs.clamp(eps, 1. - eps)  # Will no longer sum to 1
    return probs / probs.sum(-1, keepdim=True)  # back onto the simplex


def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Return a coordinate grid of shape ``(1, 2, h + 2*pad, w + 2*pad)``.

    Channel 0 holds each cell's column index and channel 1 its row index
    (i.e. ``(x, y)`` order).  Indices start at ``-pad`` so the un-padded
    region covers rows ``0 .. h-1`` and columns ``0 .. w-1``.

    Args:
        h: Number of rows of the un-padded grid.
        w: Number of columns of the un-padded grid.
        pad: Extra cells added on every side (they get indices ``< 0`` or
            ``>= h`` / ``>= w``).
        device: Device the grid is created on.
        dtype: Dtype of the returned tensor.
        norm: If True, divide each axis by (its padded length - 1); with
            ``pad=0`` this maps both axes onto ``[0, 1]``.  An axis whose
            padded length is 1 stays at 0 instead of becoming NaN.

    Returns:
        Tensor of shape ``(1, 2, h + 2*pad, w + 2*pad)`` and dtype ``dtype``.
    """
    n_rows = h + 2 * pad
    n_cols = w + 2 * pad
    rows = torch.arange(n_rows, device=device) - pad
    cols = torch.arange(n_cols, device=device) - pad
    if norm:
        # Guard the denominator: a length-1 axis would otherwise give 0/0 = NaN.
        rows = rows / max(n_rows - 1, 1)
        cols = cols / max(n_cols - 1, 1)
    # Same result as torch.meshgrid(rows, cols, indexing='ij'), but without the
    # version-dependent `indexing` kwarg (omitting it warns on torch >= 1.10).
    ig = rows[:, None].expand(n_rows, n_cols)
    jg = cols[None, :].expand(n_rows, n_cols)
    return torch.stack([jg, ig]).to(dtype)[None]


@functools.lru_cache(maxsize=32)
def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Memoised version of ``grid`` with the same arguments and result.

    Results are kept in a bounded LRU cache (``maxsize=32``) keyed on the call
    arguments, so all arguments must be hashable (ints, ``str`` / ``torch.device``,
    ``torch.dtype`` and bools all are).  Use ``cached_grid.cache_clear()`` to
    release cached tensors, e.g. after switching devices.

    The SAME tensor object is returned for repeated identical calls: callers
    must not modify it in place — ``.clone()`` it first if mutation is needed.
    """
    return grid(h, w, pad, device, dtype, norm)