Guess-What-Moves / dist.py
subhc's picture
Code Commit
5e88f62
import functools
import torch
import torch.distributions
import utils
LOGGER = utils.log.getLogger(__name__)
__defined_kl = False
EPS = 1e-5
def clamp_probs(probs):
probs = probs.clamp(EPS, 1. - EPS) # Will no longer sum to 1
return probs / probs.sum(-1, keepdim=True) # to simplex
def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
hr = torch.arange(h + 2 * pad, device=device) - pad
wr = torch.arange(w + 2 * pad, device=device) - pad
if norm:
hr = hr / (h + 2 * pad - 1)
wr = wr / (w + 2 * pad - 1)
ig, jg = torch.meshgrid(hr, wr)
g = torch.stack([jg, ig]).to(dtype)[None]
return g
@functools.lru_cache(2)
def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
return grid(h, w, pad, device, dtype, norm)