Spaces:
Runtime error
Runtime error
File size: 6,022 Bytes
cfb7702 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class WeightedLoss(nn.Module):
    """Base class for element-wise losses with an optional per-element weight.

    Subclasses expose the underlying functional loss through ``func``; it must
    accept ``(inputs, targets, reduction='none')`` and return an element-wise
    loss tensor (e.g. ``F.mse_loss``, ``F.l1_loss``).
    """

    @property
    def func(self):
        """Element-wise functional loss; must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self, inputs, targets, weight=None, reduction='mean'):
        """Compute the (optionally weighted) loss.

        Args:
            inputs: predictions.
            targets: ground truth, broadcast-compatible with ``inputs``.
            weight: optional per-element weight/mask; trailing singleton dims
                are appended until it broadcasts against ``inputs``.
            reduction: 'none' | 'sum' | 'mean' | 'valid_mean'. 'valid_mean'
                divides the weighted loss sum by the weight sum instead of
                the element count.
        """
        assert reduction in ['none', 'sum', 'mean', 'valid_mean']
        loss = self.func(inputs, targets, reduction='none')
        if weight is not None:
            # Right-pad weight with singleton dims so it broadcasts over loss.
            while weight.ndim < inputs.ndim:
                weight = weight[..., None]
            loss *= weight.float()
        if reduction == 'none':
            return loss
        if reduction == 'sum':
            return loss.sum()
        if reduction == 'mean':
            return loss.mean()
        # reduction == 'valid_mean'
        if weight is None:
            # Fix: the original dereferenced `weight` here and crashed with an
            # AttributeError when no weight was given; with every element
            # valid, 'valid_mean' degenerates to a plain mean.
            return loss.mean()
        # NOTE(review): when `weight` has fewer dims than `inputs`, the
        # numerator counts each weight entry once per broadcast element while
        # the denominator counts it once — confirm this normalization is the
        # intended semantics for broadcast weights.
        return loss.sum() / weight.float().sum()
class MSELoss(WeightedLoss):
    """Weighted squared-error loss; reduction modes come from WeightedLoss."""

    @property
    def func(self):
        # Element-wise squared error from torch.nn.functional.
        return F.mse_loss
class L1Loss(WeightedLoss):
    """Weighted absolute-error loss; reduction modes come from WeightedLoss."""

    @property
    def func(self):
        # Element-wise absolute error from torch.nn.functional.
        return F.l1_loss
class PSNR(nn.Module):
    """Peak signal-to-noise ratio, -10*log10(MSE).

    NOTE(review): no data-range/peak argument — the formula implicitly
    assumes a signal peak of 1.0; confirm inputs are normalized to [0, 1].
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, valid_mask=None, reduction='mean'):
        """Return PSNR over the whole batch ('mean') or per sample ('none')."""
        assert reduction in ['mean', 'none']
        sq_err = (inputs - targets)**2
        if valid_mask is not None:
            # Boolean indexing flattens sq_err to the selected elements only.
            sq_err = sq_err[valid_mask]
        if reduction == 'mean':
            mse = torch.mean(sq_err)
        else:
            # Reduce every dim except the leading (batch) one.
            mse = torch.mean(sq_err, dim=tuple(range(1, sq_err.ndim)))
        return -10 * torch.log10(mse)
class SSIM():
    """Structural Similarity (SSIM) index for BxCxHxW image batches.

    Local means/variances are computed with a Gaussian (or uniform) window
    applied as a depthwise (grouped) conv2d, with reflect padding so the
    output keeps the input's spatial size.

    Args:
        data_range: (min, max) of the expected input value range; its span
            scales the k1/k2 stabilization constants.
        kernel_size: (h, w) of the window; both must be odd and positive.
        sigma: per-axis Gaussian std-dev (unused when gaussian=False).
        k1, k2: SSIM stabilization constants.
        gaussian: use a Gaussian window if True, a uniform one otherwise.
    """

    def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True):
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.gaussian = gaussian
        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.")
        if any(y <= 0 for y in self.sigma):
            raise ValueError(f"Expected sigma to have positive number. Got {sigma}.")
        data_scale = data_range[1] - data_range[0]
        self.c1 = (k1 * data_scale)**2
        self.c2 = (k2 * data_scale)**2
        # "Same" padding for the sliding window (kernel sizes are odd).
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

    def _uniform(self, kernel_size):
        """1-D uniform window: 1/(hi - lo) inside [lo, hi], zero outside."""
        hi, lo = 2.5, -2.5  # renamed from max/min to avoid shadowing builtins
        ksize_half = (kernel_size - 1) * 0.5
        coords = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        kernel = torch.zeros(kernel_size)
        inside = (coords >= lo) & (coords <= hi)
        kernel[inside] = 1 / (hi - lo)
        return kernel.unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian(self, kernel_size, sigma):
        """Normalized 1-D Gaussian window of the given size and std-dev."""
        ksize_half = (kernel_size - 1) * 0.5
        kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        gauss = torch.exp(-0.5 * (kernel / sigma).pow(2))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
        """Outer product of two 1-D windows -> (kernel_h, kernel_w) kernel."""
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])
        return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    def __call__(self, output, target, reduction='mean'):
        """Return SSIM per sample ('none'), summed ('sum') or averaged ('mean')."""
        if output.dtype != target.dtype:
            raise TypeError(
                f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}."
            )
        if output.shape != target.shape:
            raise ValueError(
                f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}."
            )
        if len(output.shape) != 4 or len(target.shape) != 4:
            raise ValueError(
                f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}."
            )
        assert reduction in ['mean', 'sum', 'none']
        batch, channel = output.size(0), output.size(1)
        # Build a (C, 1, kh, kw) depthwise kernel on the input's device/dtype.
        # Fix: the original cached the channel-expanded kernel on self, which
        # broke reuse of one SSIM instance on inputs with a different channel
        # count, and it never moved the kernel off CPU/float32, failing for
        # CUDA or non-float32 inputs.
        kernel = self._kernel
        if kernel.ndim < 4:
            kernel = kernel.expand(channel, 1, -1, -1)
        kernel = kernel.to(device=output.device, dtype=output.dtype)
        output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
        target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
        # Compute the five local statistics with a single grouped convolution.
        stats = F.conv2d(
            torch.cat([output, target, output * output, target * target, output * target]),
            kernel,
            groups=channel,
        )
        # Fix: slice out exactly the 5 stacked statistics; the original
        # iterated range(len(outputs)) == 5*batch and built useless empty
        # slices past index 4.
        mu_pred, mu_target, pred_sq, target_sq, pred_target = (
            stats[i * batch:(i + 1) * batch] for i in range(5)
        )
        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target
        sigma_pred_sq = pred_sq - mu_pred_sq
        sigma_target_sq = target_sq - mu_target_sq
        sigma_pred_target = pred_target - mu_pred_target
        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2
        ssim_idx = (a1 * a2) / (b1 * b2)
        _ssim = torch.mean(ssim_idx, (1, 2, 3))  # per-sample SSIM
        if reduction == 'none':
            return _ssim
        elif reduction == 'sum':
            return _ssim.sum()
        elif reduction == 'mean':
            return _ssim.mean()
def binary_cross_entropy(input, target, reduction='mean'):
    """Numerically-stable element-wise binary cross entropy on probabilities.

    F.binary_cross_entropy is not numerically stable in mixed-precision
    training, so the loss is computed by hand here.

    Args:
        input: predicted probabilities in [0, 1].
        target: binary (or soft) targets, broadcast-compatible with `input`.
        reduction: 'mean' | 'sum' | 'none'.

    Raises:
        ValueError: on an unknown `reduction` (the original silently
            returned None; it also lacked the 'sum' branch).
    """
    # Fix: clamp the log terms so inputs of exactly 0 or 1 yield a large
    # finite loss instead of inf/nan gradients — the same -100 floor that
    # F.binary_cross_entropy applies, which is what the stability claim in
    # the original docstring required.
    log_p = torch.clamp(torch.log(input), min=-100)
    log_not_p = torch.clamp(torch.log(1 - input), min=-100)
    loss = -(target * log_p + (1 - target) * log_not_p)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"Unsupported reduction: {reduction}")
|