import functools import torch from torch.nn import functional as F def reduce_loss(loss, reduction): """Reduce loss as specified. Args: loss (Tensor): Elementwise loss tensor. reduction (str): Options are 'none', 'mean' and 'sum'. Returns: Tensor: Reduced loss tensor. """ reduction_enum = F._Reduction.get_enum(reduction) # none: 0, elementwise_mean:1, sum: 2 if reduction_enum == 0: return loss elif reduction_enum == 1: return loss.mean() else: return loss.sum() def weight_reduce_loss(loss, weight=None, reduction='mean'): """Apply element-wise weight and reduce loss. Args: loss (Tensor): Element-wise loss. weight (Tensor): Element-wise weights. Default: None. reduction (str): Same as built-in losses of PyTorch. Options are 'none', 'mean' and 'sum'. Default: 'mean'. Returns: Tensor: Loss values. """ # if weight is specified, apply element-wise weight if weight is not None: assert weight.dim() == loss.dim() assert weight.size(1) == 1 or weight.size(1) == loss.size(1) loss = loss * weight # if weight is not specified or reduction is sum, just reduce the loss if weight is None or reduction == 'sum': loss = reduce_loss(loss, reduction) # if reduction is mean, then compute mean over weight region elif reduction == 'mean': if weight.size(1) > 1: weight = weight.sum() else: weight = weight.sum() * loss.size(1) loss = loss.sum() / weight return loss def weighted_loss(loss_func): """Create a weighted version of a given loss function. To use this decorator, the loss function must have the signature like `loss_func(pred, target, **kwargs)`. The function only needs to compute element-wise loss without any reduction. This decorator will add weight and reduction arguments to the function. The decorated function will have the signature like `loss_func(pred, target, weight=None, reduction='mean', **kwargs)`. :Example: >>> import torch >>> @weighted_loss >>> def l1_loss(pred, target): >>> return (pred - target).abs() >>> pred = torch.Tensor([0, 2, 3]) >>> target = torch.Tensor([1, 1, 1]) >>> weight = torch.Tensor([1, 0, 1]) >>> l1_loss(pred, target) tensor(1.3333) >>> l1_loss(pred, target, weight) tensor(1.5000) >>> l1_loss(pred, target, reduction='none') tensor([1., 1., 2.]) >>> l1_loss(pred, target, weight, reduction='sum') tensor(3.) """ @functools.wraps(loss_func) def wrapper(pred, target, weight=None, reduction='mean', **kwargs): # get element-wise loss loss = loss_func(pred, target, **kwargs) loss = weight_reduce_loss(loss, weight, reduction) return loss return wrapper def get_local_weights(residual, ksize): """Get local weights for generating the artifact map of LDL. It is only called by the `get_refined_artifact_map` function. Args: residual (Tensor): Residual between predicted and ground truth images. ksize (Int): size of the local window. Returns: Tensor: weight for each pixel to be discriminated as an artifact pixel """ pad = (ksize - 1) // 2 residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) return pixel_level_weight def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): """Calculate the artifact map of LDL (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) Args: img_gt (Tensor): ground truth images. img_output (Tensor): output images given by the optimizing model. img_ema (Tensor): output images given by the ema model. ksize (Int): size of the local window. Returns: overall_weight: weight for each pixel to be discriminated as an artifact pixel (calculated based on both local and global observations). """ residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) overall_weight = patch_level_weight * pixel_level_weight overall_weight[residual_sr < residual_ema] = 0 return overall_weight