# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py
import cv2
import numpy as np
import torch

from .common import filter2D


class USMSharp(torch.nn.Module):
    """Unsharp-mask (USM) sharpening wrapped as a ``torch.nn.Module``.

    A square 2-D Gaussian kernel is built once at construction time and
    registered as a buffer, so it follows the module through device/dtype
    moves and is saved in ``state_dict``.
    """

    def __init__(self, radius=50, sigma=0):
        """Build the Gaussian blur kernel used for sharpening.

        Args:
            radius: Side length of the square Gaussian kernel (despite the
                name, this is a kernel *size*). An even value is bumped up
                by one so the kernel has a well-defined centre.
            sigma: Standard deviation forwarded to ``cv2.getGaussianKernel``;
                a non-positive value lets OpenCV derive it from the size.
        """
        super().__init__()
        # Force an odd kernel size.
        if radius % 2 == 0:
            radius += 1
        self.radius = radius

        # getGaussianKernel yields a (size, 1) column; its outer product with
        # itself is the separable 2-D Gaussian of shape (size, size).
        column = cv2.getGaussianKernel(radius, sigma)
        gaussian_2d = np.outer(column, column)
        # Store as float32 with a leading singleton dim: (1, size, size).
        self.register_buffer('kernel', torch.from_numpy(gaussian_2d).float().unsqueeze(0))

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img`` only where its high-frequency detail is strong.

        Args:
            img: Image tensor to sharpen. NOTE(review): the ``* 255`` scaling
                and the clip to [0, 1] below suggest values are expected in
                [0, 1]; the accepted layout is whatever ``filter2D`` (from
                ``.common``, not visible here) supports — confirm.
            weight: Multiplier on the high-frequency residual that is added
                back to the image.
            threshold: Minimum absolute residual, on a 0-255 scale, for a
                pixel to be treated as detail worth sharpening.

        Returns:
            A per-pixel blend of the clipped sharpened image and the original
            ``img``, weighted by a Gaussian-smoothed version of the detail
            mask.
        """
        # High-frequency component: original minus its Gaussian blur.
        detail = img - filter2D(img, self.kernel)

        # Hard 0/1 mask of strong detail, then blurred with the same kernel so
        # the transition between sharpened and untouched regions is smooth.
        hard_mask = (detail.abs() * 255 > threshold).float()
        blend = filter2D(hard_mask, self.kernel)

        boosted = torch.clip(img + weight * detail, 0, 1)
        return blend * boosted + (1 - blend) * img