File size: 815 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import functools

import torch
import torch.distributions

import utils

# Module-level logger, obtained from the project's ``utils.log`` helper.
LOGGER = utils.log.getLogger(__name__)

# NOTE(review): not read anywhere in the visible part of this module;
# presumably a flag tracking whether custom KL divergences have been
# registered with ``torch.distributions`` -- confirm against the rest of
# the file before relying on it.
__defined_kl = False

# Margin used by ``clamp_probs`` to keep probabilities away from exactly
# 0 and 1 before renormalizing.
EPS = 1e-5


def clamp_probs(probs, eps=None):
    """Clamp probabilities away from 0 and 1, then renormalize to the simplex.

    Args:
        probs: Tensor of probabilities. The last dimension is the one that
            gets renormalized, i.e. it is treated as the distribution axis.
        eps: Clamp margin; values are first clamped to ``[eps, 1 - eps]``.
            ``None`` (the default) uses the module-level ``EPS``.

    Returns:
        Tensor with the same shape as ``probs`` whose last-dimension slices
        sum to 1 (up to floating-point rounding). Because of the
        renormalization, individual entries may land marginally outside
        ``[eps, 1 - eps]``.
    """
    if eps is None:
        # Looked up at call time (as before), so the default always tracks EPS.
        eps = EPS
    clamped = probs.clamp(eps, 1. - eps)  # no longer sums to 1
    return clamped / clamped.sum(-1, keepdim=True)  # back onto the simplex


def _grid_axis(n, pad, device, dtype, norm):
    """Return the 1-D coordinates ``-pad, ..., n + pad - 1`` of one grid axis.

    With ``norm`` the coordinates are divided by ``n + 2 * pad - 1``; that
    divisor is floored at 1 so a single-coordinate axis yields 0 instead of
    ``0 / 0 = NaN``.
    """
    size = n + 2 * pad
    coords = torch.arange(size, device=device) - pad
    if norm:
        if dtype.is_floating_point:
            # Divide in at least the requested precision. Dividing the integer
            # arange directly happens in the default float dtype, which would
            # round a float64 grid through float32.
            work_dtype = torch.promote_types(dtype, torch.get_default_dtype())
            coords = coords.to(work_dtype)
        coords = coords / max(size - 1, 1)
    return coords


def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Build a dense coordinate grid of shape ``(1, 2, h + 2*pad, w + 2*pad)``.

    Channel 0 holds the column index and channel 1 the row index. Along
    each axis the coordinates run from ``-pad`` to ``size + pad - 1``,
    where ``size`` is ``w`` (columns) or ``h`` (rows).

    Args:
        h: Number of rows of the un-padded grid.
        w: Number of columns of the un-padded grid.
        pad: Extra coordinates added on every side of both axes.
        device: Device on which the coordinates are created.
        dtype: dtype of the returned tensor.
        norm: If True, divide each axis by its number of steps
            (``h + 2*pad - 1`` for rows, ``w + 2*pad - 1`` for columns), so
            the padded axis spans an interval of length 1 starting at
            ``-pad / (steps)``. An axis with a single coordinate yields 0.

    Returns:
        Tensor of shape ``(1, 2, h + 2*pad, w + 2*pad)`` and dtype ``dtype``.
    """
    rows = _grid_axis(h, pad, device, dtype, norm)
    cols = _grid_axis(w, pad, device, dtype, norm)
    # Explicit broadcasting reproduces meshgrid's 'ij' layout without relying
    # on its version-dependent default indexing (newer PyTorch warns when
    # ``indexing`` is omitted).
    ig, jg = torch.broadcast_tensors(rows[:, None], cols[None, :])
    return torch.stack([jg, ig]).to(dtype)[None]


@functools.lru_cache(maxsize=2)
def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Memoized front-end to ``grid`` that remembers the two latest calls.

    All arguments must be hashable. The cache key is the argument pattern
    exactly as passed, so equivalent spellings (e.g. positional vs. keyword
    arguments) may occupy separate cache slots.

    Warning: a cache hit returns the very same tensor object that an earlier
    call received. Modifying the result in place alters what later callers
    get, so clone it before mutating.
    """
    return grid(h, w, pad=pad, device=device, dtype=dtype, norm=norm)