File size: 1,585 Bytes
2f85de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# python3.8
"""Contains some utility operators."""

import math
import torch
import torch.distributed as dist
import torch.nn.functional as F

# Public names exported by `from <this module> import *`.
__all__ = [
    'all_gather',
    'upsample',
    'downsample',
]


def all_gather(tensor):
    """Gathers `tensor` from all processes and returns their element-wise mean.

    Despite the name, the gathered list is not returned. Every rank's copy of
    `tensor` is collected with `dist.all_gather`, and the copies are averaged
    along a new leading axis. Every rank therefore receives the same averaged
    tensor.

    If this PyTorch build has no `torch.distributed` support, or no process
    group has been initialized, the input is returned unchanged (same object).

    Note: `torch.distributed.all_gather` is not autograd-aware, so no gradient
    flows through the gathered result.

    Args:
        tensor: Tensor to gather. Must have the same shape and dtype on every
            rank. Should be floating point (or complex), because `mean()` is
            not defined for integer dtypes.

    Returns:
        The mean of the gathered tensors, or `tensor` itself when not running
        distributed.
    """
    # `dist.is_initialized` only exists when the build ships distributed
    # support, so check availability first to avoid an AttributeError.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # The receive buffers are fully overwritten by all_gather, so their
    # initial contents do not matter.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.stack(gathered, dim=0).mean(dim=0)


def upsample(img_nerf, size, filter=None):
    """Upsamples an image batch towards a spatial side length of `size`.

    The integer upsampling factor is ``size // img_nerf.size(-1)``. If it is
    at most 1, the input is returned unchanged (same object).

    Args:
        img_nerf: Image tensor. The bilinear path requires a 4-D (N, C, H, W)
            tensor, as `F.interpolate(mode='bilinear')` does. The factor is
            computed from the last (width) dimension only; square images are
            presumably assumed (TODO confirm with callers).
        size: Target side length in pixels.
        filter: Optional FIR filter for StyleGAN2's `upfirdn2d` ops. When
            given, the image is upsampled 2x repeatedly, ``int(log2(factor))``
            times. When None, bilinear interpolation resizes directly to
            (size, size).

    Returns:
        The upsampled tensor, or `img_nerf` itself when no upsampling is
        needed.
    """
    up = size // img_nerf.size(-1)
    if up <= 1:
        return img_nerf

    if filter is not None:
        # Imported lazily: the custom op is only needed on the filtered path.
        from third_party.stylegan2_official_ops import upfirdn2d
        # Each step doubles the resolution, so the output reaches exactly
        # `size` only when `up` is a power of two.
        for _ in range(int(math.log2(up))):
            # Use the 2x *upsampling* op. The previous call to `downsample2d`
            # with `up=2` invoked the decimation op instead.
            img_nerf = upfirdn2d.upsample2d(img_nerf, filter, up=2)
    else:
        img_nerf = F.interpolate(img_nerf, (size, size),
                                 mode='bilinear',
                                 align_corners=False)
    return img_nerf


def downsample(img0, size, filter=None):
    """Shrinks an image batch towards a spatial side length of `size`.

    The reduction factor is ``img0.size(-1) // size``. When that factor is 1
    or less, `img0` itself is handed back untouched.

    Args:
        img0: Image tensor. The bilinear path needs a 4-D (N, C, H, W) tensor,
            as required by `F.interpolate(mode='bilinear')`. Only the last
            (width) dimension is used to compute the factor.
        size: Desired side length in pixels.
        filter: Optional FIR filter for StyleGAN2's `upfirdn2d` ops. With a
            filter, the image is halved ``int(log2(factor))`` times. Without
            one, it is resized straight to (size, size) bilinearly.

    Returns:
        The downsampled tensor, or `img0` when no reduction is needed.
    """
    factor = img0.size(-1) // size
    if factor > 1:
        if filter is None:
            img0 = F.interpolate(img0, (size, size), mode='bilinear', align_corners=False)
        else:
            # Lazy import keeps the custom op optional for the bilinear path.
            from third_party.stylegan2_official_ops import upfirdn2d
            # Repeated 2x decimation only lands exactly on `size` when the
            # factor is a power of two.
            halvings = int(math.log2(factor))
            for _ in range(halvings):
                img0 = upfirdn2d.downsample2d(img0, filter, down=2)
    return img0