|
import itertools |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from torch.nn import functional as F |
|
import cv2 |
|
import distutils.util |
|
|
|
def _to_display_image(batch):
    """Return the first image of a (N, C, H, W) tensor as an (H, W, C) array clipped to [0, 1].

    Values are mapped with ``x * 0.5 + 0.5`` before clipping, which assumes the
    tensors were normalized to [-1, 1] (TODO confirm against the data pipeline).
    """
    first = batch.cpu().numpy()[0]
    return np.transpose(np.clip(first * 0.5 + 0.5, 0, 1), [1, 2, 0])


def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a side-by-side LR / network-output / HR comparison figure for one epoch.

    Runs ``G_net`` on ``imgs_lr`` without gradient tracking, then plots the first
    sample of ``imgs_lr``, of the network output and of ``imgs_hr`` in three
    panels with hidden axes and an ``Epoch <num_epoch>`` caption.

    Args:
        num_epoch: Epoch number; used in the caption and in the output file name.
        G_net: Model called as ``G_net(imgs_lr)`` (presumably the SR generator);
            its output is shown in the middle panel.
        imgs_lr: Input image batch (tensor), shown in the left panel.
        imgs_hr: Reference image batch (tensor), shown in the right panel.

    The figure is written to ``results/train_out/epoch_<num_epoch>_results.png``;
    that directory is not created here. All open figures are closed afterwards.
    """
    with torch.no_grad():
        test_images = G_net(imgs_lr)

    fig, axes = plt.subplots(1, 3)
    # Panels in display order: input, network output, reference.
    for axis, batch in zip(axes, (imgs_lr, test_images, imgs_hr)):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.imshow(_to_display_image(batch))

    fig.text(0.5, 0.04, f"Epoch {num_epoch}", ha='center')
    # Save the figure we built explicitly rather than relying on pyplot's "current figure".
    fig.savefig(f"results/train_out/epoch_{num_epoch}_results.png")
    plt.close('all')
|
|
|
|
|
|
|
|
|
|
|
def cvtColor(image):
    """Return ``image`` as-is when it is already 3-channel, otherwise convert it to RGB.

    "3-channel" means ``np.shape(image)`` is ``(H, W, 3)``. Anything else is
    converted with ``image.convert('RGB')``, so it must expose a PIL-style
    ``convert`` method (presumably a PIL image; verify against callers).
    """
    dims = np.shape(image)
    if len(dims) != 3 or dims[2] != 3:
        return image.convert('RGB')
    return image
|
|
|
def preprocess_input(image, mean, std):
    """Standardize pixel data: divide by 255, subtract ``mean``, divide by ``std``.

    ``mean`` and ``std`` are applied with normal NumPy broadcasting, so they may
    be scalars or per-channel arrays.
    """
    scaled = image / 255
    centered = scaled - mean
    return centered / std
|
|
|
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or ``None`` if it has no groups."""
    return next((group['lr'] for group in optimizer.param_groups), None)
|
|
|
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name, between banner lines.

    ``args`` is anything ``vars()`` accepts, e.g. an ``argparse.Namespace``.
    """
    options = vars(args)
    print("----------- Configuration Arguments -----------")
    for name in sorted(options):
        print("%s: %s" % (name, options[name]))
    print("------------------------------------------------")
|
|
|
|
|
def add_arguments(argname, type, default, help, argparser, **kwargs): |
|
type = distutils.util.strtobool if type == bool else type |
|
argparser.add_argument("--" + argname, |
|
default=default, |
|
type=type, |
|
help=help + ' 默认: %(default)s.', |
|
**kwargs) |
|
|
|
def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D: convolve a batch of images with odd-sized kernels.

    Each image is reflect-padded by ``k // 2`` on every side, so the output keeps
    the input's spatial size.

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k) with odd k. A leading dimension of 1 means one
            kernel shared by every image and channel; otherwise kernel i is
            applied to every channel of image i.

    Returns:
        Tensor: (b, c, h, w) filtered images.

    Raises:
        ValueError: if k is even.
    """
    ksize = kernel.size(-1)
    batch, channels, height, width = img.size()
    if ksize % 2 != 1:
        raise ValueError('Wrong kernel size')

    half = ksize // 2
    padded = F.pad(img, (half, half, half, half), mode='reflect')
    padded_h, padded_w = padded.size()[-2:]

    if kernel.size(0) != 1:
        # Per-image kernels: copy each image's kernel across its channels and run a
        # single grouped convolution in which every (image, channel) plane is a group.
        grouped = padded.view(1, batch * channels, padded_h, padded_w)
        weights = (kernel.view(batch, 1, ksize, ksize)
                   .repeat(1, channels, 1, 1)
                   .view(batch * channels, 1, ksize, ksize))
        out = F.conv2d(grouped, weights, groups=batch * channels)
    else:
        # One shared kernel: fold channels into the batch dim and filter each plane.
        planes = padded.view(batch * channels, 1, padded_h, padded_w)
        out = F.conv2d(planes, kernel.view(1, 1, ksize, ksize), padding=0)
    return out.view(batch, channels, height, width)
|
|
|
|
|
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp-mask) sharpening.

    Input image: I; blurry image: B (Gaussian blur of I).

    1. sharp = clip(I + weight * (I - B), 0, 1)
    2. Mask = 1 where abs(I - B) * 255 > threshold, else 0
    3. SoftMask = Gaussian blur of Mask (same kernel size as for B)
    4. Out = SoftMask * sharp + (1 - SoftMask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (int): Kernel size of the Gaussian blur; an even value is bumped
            to the next odd one. Default: 50 (effective kernel size 51).
        threshold (int): Residual threshold on the 0-255 scale; only pixels with
            abs(I - B) * 255 above it are sharpened. Default: 10.

    Returns:
        Numpy array: The sharpened image, same shape as ``img``.
    """
    # OpenCV requires an odd Gaussian kernel size.
    if radius % 2 == 0:
        radius += 1
    # sigmaX=0 lets OpenCV derive sigma from the kernel size.
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    # Residual is on the [0, 1] scale; * 255 compares it against a 0-255 threshold.
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    # Blurring the binary mask feathers the edge between sharpened and untouched pixels.
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
|
|
|
|
|
class USMSharp(torch.nn.Module):
    """Torch module applying unsharp-mask (USM) sharpening with a fixed Gaussian kernel.

    The 2-D kernel is built once with OpenCV and registered as the buffer
    ``kernel`` with shape (1, k, k), so it moves with the module on ``.to()``.
    """

    def __init__(self, radius=50, sigma=0):
        super().__init__()
        # Odd kernel size required (filter2D rejects even kernels).
        ksize = radius + 1 if radius % 2 == 0 else radius
        self.radius = ksize
        # getGaussianKernel returns a (ksize, 1) column; its outer product is the 2-D kernel.
        column = cv2.getGaussianKernel(ksize, sigma)
        kernel_2d = torch.FloatTensor(column @ column.T)
        self.register_buffer('kernel', kernel_2d.unsqueeze(0))

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img`` and blend it back only around strong residuals.

        Args:
            img (Tensor): (b, c, h, w) batch; presumably values in [0, 1], since the
                sharpened image is clipped to that range (TODO confirm).
            weight (float): Multiplier on the high-frequency residual.
            threshold (int): Residual threshold on the 0-255 scale.
        """
        residual = img - filter2D(img, self.kernel)
        edge_mask = (torch.abs(residual) * 255 > threshold).float()
        # Feather the binary mask with the same Gaussian kernel.
        soft_mask = filter2D(edge_mask, self.kernel)
        sharpened = torch.clip(img + weight * residual, 0, 1)
        return soft_mask * sharpened + (1 - soft_mask) * img
|
|
|
class USMSharp_npy:
    """NumPy/OpenCV unsharp-mask (USM) sharpener with a precomputed Gaussian kernel.

    ``self.kernel`` is a float32 (k, k) Gaussian kernel used with ``cv2.filter2D``.
    """

    def __init__(self, radius=50, sigma=0):
        super().__init__()
        # Force an odd kernel size.
        ksize = radius + 1 if radius % 2 == 0 else radius
        self.radius = ksize
        # Outer product of the 1-D (ksize, 1) Gaussian column gives the 2-D kernel.
        column = cv2.getGaussianKernel(ksize, sigma)
        self.kernel = (column @ column.T).astype(np.float32)

    def filt(self, img, weight=0.5, threshold=10):
        """Return ``img`` sharpened where the residual exceeds ``threshold`` (0-255 scale).

        ``img`` is presumably a float image in [0, 1], since the sharpened result
        is clipped to that range (TODO confirm).
        """
        blurred = cv2.filter2D(img, -1, self.kernel)
        residual = img - blurred
        edge_mask = (np.abs(residual) * 255 > threshold).astype(np.float32)
        # Feather the binary mask so sharpened regions blend smoothly.
        soft_mask = cv2.filter2D(edge_mask, -1, self.kernel)
        sharpened = np.clip(img + weight * residual, 0, 1)
        return soft_mask * sharpened + (1 - soft_mask) * img
|
|
|
|