import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
import cv2
import distutils.util

def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    with torch.no_grad():
        test_images = G_net(imgs_lr)

        fig, ax = plt.subplots(1, 2)

        for j in range(2):
            ax[j].get_xaxis().set_visible(False)
            ax[j].get_yaxis().set_visible(False)

        ax[0].cla()
        ax[0].imshow(np.transpose(test_images.cpu().numpy()[0] * 0.5 + 0.5, [1, 2, 0]))

        ax[1].cla()
        ax[1].imshow(np.transpose(imgs_hr.cpu().numpy()[0] * 0.5 + 0.5, [1, 2, 0]))

        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')
        plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
        plt.close('all')  # close all figures to avoid memory leaks

#---------------------------------------------------------#
#   Convert the input image to RGB so that grayscale images
#   do not cause errors at prediction time. The code only
#   supports prediction on RGB images; every other image
#   type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image
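
# A minimal usage sketch, not part of the original project: it assumes Pillow is
# available and uses a placeholder file name. cvtColor passes 3-channel images
# through unchanged and converts anything else (e.g. grayscale or RGBA PIL images)
# via Image.convert('RGB').
def _example_cvtColor(path="input.png"):
    from PIL import Image
    image = Image.open(path)
    return cvtColor(image)
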
def preprocess_input(image, mean, std):
    image = (image / 255 - mean) / std
    return image

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def print_arguments(args):
    print("----------- Configuration Arguments -----------")
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")

def add_arguments(argname, type, default, help, argparser, **kwargs):
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)
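
# A hedged usage sketch, not called anywhere in this file: the argument names and
# defaults below are invented for illustration. add_arguments is just a thin wrapper
# around argparse.ArgumentParser.add_argument that parses bool values with strtobool.
def _example_add_arguments():
    import argparse
    parser = argparse.ArgumentParser()
    add_arguments("batch_size", int, 16, "Training batch size", parser)
    add_arguments("use_cuda", bool, True, "Whether to train on GPU", parser)
    args = parser.parse_args([])   # parse defaults only
    print_arguments(args)          # prints the configuration table defined above
    return args
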
def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k)
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Wrong kernel size')

    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # apply the same kernel to all batch images
        img = img.view(b * c, 1, ph, pw)
        kernel = kernel.view(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    else:
        img = img.view(1, b * c, ph, pw)
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
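
# A minimal sketch of how filter2D is meant to be called; the shapes follow the
# docstring above and the values are arbitrary. A batch can be filtered with either
# one shared kernel (kernel batch dim 1) or one kernel per image.
def _example_filter2D():
    img = torch.rand(4, 3, 32, 32)                    # (b, c, h, w)
    shared_kernel = torch.ones(1, 3, 3) / 9.0         # one 3x3 box filter for the whole batch
    per_image_kernels = torch.rand(4, 3, 3)           # one kernel per batch element
    out_shared = filter2D(img, shared_kernel)         # -> (4, 3, 32, 32)
    out_per_image = filter2D(img, per_image_kernels)  # -> (4, 3, 32, 32)
    return out_shared, out_per_image
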
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask to get a soft mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Mask threshold on abs(I - B) scaled to [0, 255]. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
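
# An illustrative sketch of the NumPy/OpenCV path (the file name is a placeholder):
# read a BGR image, scale it to float32 in [0, 1] as the docstring expects, sharpen,
# and convert back to uint8 for saving.
def _example_usm_sharp(path="input.png"):
    img = cv2.imread(path).astype(np.float32) / 255.0
    sharpened = usm_sharp(img, weight=0.5, radius=50, threshold=10)
    return (sharpened * 255.0).round().astype(np.uint8)
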
class USMSharp(torch.nn.Module):

    def __init__(self, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        blur = filter2D(img, self.kernel)
        residual = img - blur

        mask = torch.abs(residual) * 255 > threshold
        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel)
        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
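
# A sketch of the tensor version (shapes assumed, values arbitrary): USMSharp is an
# nn.Module, so the Gaussian kernel is registered as a buffer and follows the module
# when it is moved with .to(device).
def _example_usm_sharp_module():
    sharpener = USMSharp(radius=50, sigma=0)
    img = torch.rand(2, 3, 64, 64)                   # (b, c, h, w), values in [0, 1]
    return sharpener(img, weight=0.5, threshold=10)  # same shape as the input
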
class USMSharp_npy():

    def __init__(self, radius=50, sigma=0):
        super(USMSharp_npy, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)

    def filt(self, img, weight=0.5, threshold=10):
        blur = cv2.filter2D(img, -1, self.kernel)
        residual = img - blur

        mask = np.abs(residual) * 255 > threshold
        mask = mask.astype(np.float32)
        soft_mask = cv2.filter2D(mask, -1, self.kernel)
        sharp = img + weight * residual
        sharp = np.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img