File size: 890 Bytes
251e479
 
 
 
2a8678c
 
251e479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8678c
251e479
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import einops
import torch
import torch.nn.functional as F

# Default device string for tensors created by helpers in this module:
# 'cuda' when a CUDA device is visible to torch, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


@torch.no_grad()
def find_flat_region(mask):
    """Mark pixels whose 3x3 neighbourhood has zero Prewitt gradient.

    Args:
        mask: single-channel tensor. A batch axis is added with
            ``unsqueeze(0)`` and it is convolved with (1, 1, 3, 3) kernels, so
            it is expected to be shaped (1, H, W). Floating masks keep their
            dtype; bool/integer masks are promoted to float32 first.

    Returns:
        float32 tensor of shape (1, H, W) holding 1.0 where both the
        horizontal and vertical Prewitt responses are exactly zero, else 0.0.
        Borders use replicate padding, so pixels outside the image are treated
        as copies of the nearest edge pixel.

    Note:
        A zero Prewitt response does not strictly imply a constant
        neighbourhood: opposing differences in different rows/columns can
        cancel (e.g. [[0, 0, 1], [0, 0, 0], [1, 0, 0]] scores as flat).
    """
    # Replicate padding and conv2d only support floating tensors; promote
    # bool/integer masks instead of failing inside F.pad / F.conv2d.
    if not mask.is_floating_point():
        mask = mask.float()

    # Build the kernels in the mask's own dtype/device. A hard-coded float32
    # kernel makes conv2d raise a dtype mismatch for float64/float16 masks.
    kernel_x = torch.tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]],
                            dtype=mask.dtype,
                            device=mask.device).unsqueeze(0).unsqueeze(0)
    kernel_y = torch.tensor([[-1, -1, -1], [0, 0, 0], [1, 1, 1]],
                            dtype=mask.dtype,
                            device=mask.device).unsqueeze(0).unsqueeze(0)

    # Pad by one pixel so the output keeps the input's spatial size.
    mask_ = F.pad(mask.unsqueeze(0), (1, 1, 1, 1), mode='replicate')

    grad_x = F.conv2d(mask_, kernel_x)
    grad_y = F.conv2d(mask_, kernel_y)
    # [0] drops the batch axis added above.
    return ((grad_x.abs() + grad_y.abs()) == 0).float()[0]


def numpy2tensor(img, target_device=None):
    """Convert an HWC numpy image to a batched CHW tensor scaled to [-1, 1].

    Args:
        img: numpy array indexed as (height, width, channels); the permute
            below requires exactly three dimensions. Values are presumably in
            [0, 255] (e.g. uint8) so that the result lands in [-1, 1] --
            TODO confirm for float inputs.
        target_device: optional device to place the result on. When None
            (default), the module-level ``device`` is used, as before.

    Returns:
        float32 tensor of shape (1, C, H, W) equal to ``img / 255 * 2 - 1``.
        It is a permuted view of a fresh tensor (channels-last strides), so
        call ``.contiguous()`` if an NCHW-contiguous layout is required.
    """
    dev = device if target_device is None else target_device
    # Copy first: torch.from_numpy rejects negative-stride arrays (e.g. a
    # BGR->RGB view img[..., ::-1]) and would otherwise share img's memory.
    x0 = torch.from_numpy(img.copy()).float().to(dev) / 255.0 * 2.0 - 1.
    # Add the batch axis and reorder (b, h, w, c) -> (b, c, h, w). The
    # arithmetic above already produced a new tensor, so no clone is needed.
    return x0.unsqueeze(0).permute(0, 3, 1, 2)