File size: 4,398 Bytes
5d756f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import tops


def denormalize_img(image, mean=0.5, std=0.5):
    image = image * std + mean
    image = torch.clamp(image.float(), 0, 1)
    image = (image * 255)
    image = torch.round(image)
    return image / 255


@torch.no_grad()
def im2numpy(images, to_uint8=False, denormalize=False):
    if denormalize:
        images = denormalize_img(images)
        if images.dtype != torch.uint8:
            images = images.clamp(0, 1)
    return tops.im2numpy(images, to_uint8=to_uint8)


@torch.no_grad()
def im2torch(im, cuda=False, normalize=True, to_float=True):
    im = tops.im2torch(im, cuda=cuda, to_float=to_float)
    if normalize:
        assert im.min() >= 0.0 and im.max() <= 1.0
        if normalize:
            im = im * 2 - 1
    return im


@torch.no_grad()
def binary_dilation(im: torch.Tensor, kernel: torch.Tensor):
    assert len(im.shape) == 4
    assert len(kernel.shape) == 2
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    padding = kernel.shape[-1]//2
    assert kernel.shape[-1] % 2 != 0
    if isinstance(im, torch.cuda.FloatTensor):
        im, kernel = im.half(), kernel.half()
    else:
        im, kernel = im.float(), kernel.float()
    im = torch.nn.functional.conv2d(
        im, kernel, groups=im.shape[1], padding=padding)
    im = im > 0.5
    return im


@torch.no_grad()
def binary_erosion(im: torch.Tensor, kernel: torch.Tensor):
    assert len(im.shape) == 4
    assert len(kernel.shape) == 2
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    padding = kernel.shape[-1]//2
    assert kernel.shape[-1] % 2 != 0
    if isinstance(im, torch.cuda.FloatTensor):
        im, kernel = im.half(), kernel.half()
    else:
        im, kernel = im.float(), kernel.float()
    ksum = kernel.sum()
    padding = (padding, padding, padding, padding)
    im = torch.nn.functional.pad(im, padding, mode="reflect")
    im = torch.nn.functional.conv2d(
        im, kernel, groups=im.shape[1])
    return im.round() == ksum


def set_requires_grad(value: torch.nn.Module, requires_grad: bool):
    if isinstance(value, (list, tuple)):
        for param in value:
            param.requires_grad = requires_grad
        return
    for p in value.parameters():
        p.requires_grad = requires_grad


def forward_D_fake(batch, fake_img, discriminator, **kwargs):
    fake_batch = {k: v for k, v in batch.items() if k != "img"}
    fake_batch["img"] = fake_img
    return discriminator(**fake_batch, **kwargs)


def remove_pad(x: torch.Tensor, bbox_XYXY, imshape):
    """
    Remove padding that is shown as negative 
    """
    H, W = imshape
    x0, y0, x1, y1 = bbox_XYXY
    padding = [
        max(0, -x0),
        max(0, -y0),
        max(x1 - W, 0),
        max(y1 - H, 0)
    ]
    x0, y0 = padding[:2]
    x1 = x.shape[2] - padding[2]
    y1 = x.shape[1] - padding[3]
    return x[:, y0:y1, x0:x1]


def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor:
    """
        Crops x by bbox_XYXY. 
    """
    x0, y0, x1, y1 = bbox_XYXY
    x0 = max(x0, 0)
    y0 = max(y0, 0)
    x1 = min(x1, x.shape[-1])
    y1 = min(y1, x.shape[-2])
    return x[..., y0:y1, x0:x1]


def torch_wasserstein_loss(tensor_a, tensor_b):
    # Compute the first Wasserstein distance between two 1D distributions.
    return (torch_cdf_loss(tensor_a, tensor_b, p=1))


def torch_cdf_loss(tensor_a, tensor_b, p=1):
    # last-dimension is weight distribution
    # p is the norm of the distance, p=1 --> First Wasserstein Distance
    # to get a positive weight with our normalized distribution
    # we recommend combining this loss with other difference-based losses like L1

    # normalize distribution, add 1e-14 to divisor to avoid 0/0
    tensor_a = tensor_a / (torch.sum(tensor_a, dim=-1, keepdim=True) + 1e-14)
    tensor_b = tensor_b / (torch.sum(tensor_b, dim=-1, keepdim=True) + 1e-14)
    # make cdf with cumsum
    cdf_tensor_a = torch.cumsum(tensor_a, dim=-1)
    cdf_tensor_b = torch.cumsum(tensor_b, dim=-1)

    # choose different formulas for different norm situations
    if p == 1:
        cdf_distance = torch.sum(torch.abs((cdf_tensor_a-cdf_tensor_b)), dim=-1)
    elif p == 2:
        cdf_distance = torch.sqrt(torch.sum(torch.pow((cdf_tensor_a-cdf_tensor_b), 2), dim=-1))
    else:
        cdf_distance = torch.pow(torch.sum(torch.pow(torch.abs(cdf_tensor_a-cdf_tensor_b), p), dim=-1), 1/p)

    cdf_loss = cdf_distance.mean()
    return cdf_loss