import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedLoss(nn.Module):
    """Base class for element-wise losses with an optional weight tensor.

    Subclasses implement ``func`` as a property returning a functional loss that
    accepts ``reduction='none'`` (e.g. ``F.mse_loss``).
    """

    @property
    def func(self):
        raise NotImplementedError

    def forward(self, inputs, targets, weight=None, reduction='mean'):
        """Compute the (optionally weighted) loss.

        Args:
            inputs: Prediction tensor.
            targets: Target tensor, passed to ``func`` together with ``inputs``.
            weight: Optional weight tensor. Trailing singleton dims are appended
                until it has ``inputs.ndim`` dims, so it lines up with the leading
                dims of the loss and broadcasts over the remaining ones.
            reduction: One of
                - 'none': return the element-wise (weighted) loss;
                - 'sum': sum of the weighted loss;
                - 'mean': mean of the weighted loss over all elements;
                - 'valid_mean': sum of the weighted loss divided by the sum of
                  ``weight``; without ``weight`` this is the plain mean.

        Returns:
            The element-wise loss for 'none', otherwise a scalar tensor.
        """
        assert reduction in ['none', 'sum', 'mean', 'valid_mean']
        loss = self.func(inputs, targets, reduction='none')
        if weight is not None:
            while weight.ndim < inputs.ndim:
                weight = weight[..., None]
            weight = weight.float()
            loss *= weight
        if reduction == 'none':
            return loss
        elif reduction == 'sum':
            return loss.sum()
        elif reduction == 'mean':
            return loss.mean()
        # reduction == 'valid_mean'
        if weight is None:
            # Without a weight every element counts as valid.
            return loss.mean()
        # NOTE(review): the denominator sums the weight tensor *before*
        # broadcasting, so a weight with fewer dims than the loss normalizes by
        # the number of weighted leading positions, not by loss elements.
        return loss.sum() / weight.sum()


class MSELoss(WeightedLoss):
    """Weighted mean-squared-error loss (element-wise ``F.mse_loss``)."""

    @property
    def func(self):
        return F.mse_loss


class L1Loss(WeightedLoss):
    """Weighted L1 loss (element-wise ``F.l1_loss``)."""

    @property
    def func(self):
        return F.l1_loss


class PSNR(nn.Module):
    """Peak signal-to-noise ratio computed as ``-10 * log10(MSE)``.

    This is PSNR in dB for a peak signal value of 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, valid_mask=None, reduction='mean'):
        """Compute PSNR between ``inputs`` and ``targets``.

        Args:
            inputs: Prediction tensor; dim 0 is treated as the sample dim.
            targets: Target tensor of the same shape.
            valid_mask: Optional boolean mask selecting the elements that enter
                the MSE. For 'mean' it indexes the squared error directly. For
                'none' it is broadcast over trailing dims, and each sample
                averages only its own valid elements.
            reduction: 'mean' for a single PSNR over all (valid) elements, or
                'none' for one PSNR per sample (reducing all dims but dim 0).

        Returns:
            A scalar tensor ('mean') or a tensor of shape ``(inputs.shape[0],)``
            ('none').
        """
        assert reduction in ['mean', 'none']
        value = (inputs - targets) ** 2
        if reduction == 'mean':
            if valid_mask is not None:
                value = value[valid_mask]
            return -10 * torch.log10(torch.mean(value))

        # reduction == 'none': keep the sample dim and reduce everything else.
        dims = tuple(range(value.ndim)[1:])
        if valid_mask is None:
            mse = torch.mean(value, dim=dims)
        else:
            # Boolean indexing would flatten the batch; use a masked per-sample
            # mean instead so the output keeps one value per sample.
            mask = valid_mask.to(value.dtype)
            while mask.ndim < value.ndim:
                mask = mask[..., None]
            mask = mask.expand_as(value)
            mse = (value * mask).sum(dim=dims) / mask.sum(dim=dims)
        return -10 * torch.log10(mse)


class SSIM():
    """Structural similarity index (SSIM) for BxCxHxW image batches.

    Local statistics are computed by convolving each channel (grouped conv2d)
    of reflect-padded inputs with a separable Gaussian or uniform window.
    """

    def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True):
        """Configure the SSIM window and stabilizing constants.

        Args:
            data_range: ``(min, max)`` of the pixel values; its span scales
                ``c1``/``c2``.
            kernel_size: ``(height, width)`` of the window; both must be odd
                and positive.
            sigma: ``(height, width)`` Gaussian standard deviations; both must
                be positive.
            k1, k2: SSIM constants; ``c1 = (k1 * span)**2``, ``c2 = (k2 * span)**2``.
            gaussian: Use a Gaussian window if True, otherwise a uniform one.

        Raises:
            ValueError: On an invalid ``kernel_size`` or ``sigma``.
        """
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.gaussian = gaussian

        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.")

        if any(y <= 0 for y in self.sigma):
            raise ValueError(f"Expected sigma to have positive number. Got {sigma}.")

        data_scale = data_range[1] - data_range[0]
        self.c1 = (k1 * data_scale)**2
        self.c2 = (k2 * data_scale)**2
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        # 2-D window of shape (kernel_size[0], kernel_size[1]). It is kept 2-D;
        # each call expands it to the depthwise (C, 1, kh, kw) form.
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

    def _uniform(self, kernel_size):
        """Return a (1, kernel_size) box window over offsets in [-2.5, 2.5].

        The window is normalized to sum to 1. Without normalization, sizes
        below 5 would truncate the box and shrink the local means.
        """
        lower, upper = -2.5, 2.5
        ksize_half = (kernel_size - 1) * 0.5
        offsets = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        inside = (offsets >= lower) & (offsets <= upper)
        kernel = inside.float() / (upper - lower)
        kernel = kernel / kernel.sum()
        return kernel.unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian(self, kernel_size, sigma):
        """Return a (1, kernel_size) normalized 1-D Gaussian window."""
        ksize_half = (kernel_size - 1) * 0.5
        kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        gauss = torch.exp(-0.5 * (kernel / sigma).pow(2))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
        """Return the 2-D window as the outer product of two 1-D windows."""
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])

        return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    def __call__(self, output, target, reduction='mean'):
        """Compute SSIM between two BxCxHxW batches.

        Args:
            output: Predicted images, shape (B, C, H, W).
            target: Reference images, same shape and dtype as ``output``.
            reduction: 'none' for per-image SSIM of shape (B,), otherwise the
                'sum' or 'mean' over images.

        Raises:
            TypeError: If the dtypes differ.
            ValueError: If the shapes differ or the inputs are not 4-D.
        """
        if output.dtype != target.dtype:
            raise TypeError(
                f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}."
            )

        if output.shape != target.shape:
            raise ValueError(
                f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}."
            )

        if len(output.shape) != 4 or len(target.shape) != 4:
            raise ValueError(
                f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}."
            )

        assert reduction in ['mean', 'sum', 'none']

        batch_size = output.size(0)
        channel = output.size(1)

        # Build the depthwise kernel per call, without caching the expanded
        # copy, so one instance works for any channel count. Also match the
        # input's device and (floating) dtype, which conv2d requires.
        if output.is_floating_point():
            kernel = self._kernel.to(device=output.device, dtype=output.dtype)
        else:
            kernel = self._kernel.to(device=output.device)
        kernel = kernel.expand(channel, 1, -1, -1)

        padding = [self.pad_w, self.pad_w, self.pad_h, self.pad_h]
        output = F.pad(output, padding, mode="reflect")
        target = F.pad(target, padding, mode="reflect")

        # One grouped convolution computes all five local moments at once.
        input_list = torch.cat([output, target, output * output, target * target, output * target])
        outputs = F.conv2d(input_list, kernel, groups=channel)
        mu_pred, mu_target, e_pred_sq, e_target_sq, e_pred_target = outputs.split(batch_size)

        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target

        sigma_pred_sq = e_pred_sq - mu_pred_sq
        sigma_target_sq = e_target_sq - mu_target_sq
        sigma_pred_target = e_pred_target - mu_pred_target

        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2

        ssim_idx = (a1 * a2) / (b1 * b2)
        _ssim = torch.mean(ssim_idx, (1, 2, 3))

        if reduction == 'none':
            return _ssim
        elif reduction == 'sum':
            return _ssim.sum()
        elif reduction == 'mean':
            return _ssim.mean()


def binary_cross_entropy(input, target, reduction='mean'):
    """Element-wise binary cross-entropy on probabilities.

    F.binary_cross_entropy is not numerically stable in mixed-precision
    training, so this computes the loss directly. Like F.binary_cross_entropy,
    each log term is clamped to >= -100, so inputs of exactly 0 or 1 give a
    finite loss instead of inf/NaN. The clamp affects the forward value only.

    Args:
        input: Predicted probabilities.
        target: Targets, broadcastable against ``input``.
        reduction: 'none', 'sum' or 'mean'.

    Returns:
        The element-wise loss for 'none', otherwise a scalar tensor.
    """
    assert reduction in ['none', 'sum', 'mean']
    log_p = torch.log(input).clamp(min=-100)
    log_one_minus_p = torch.log(1 - input).clamp(min=-100)
    loss = -(target * log_p + (1 - target) * log_one_minus_p)
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss