Spaces:
Build error
Build error
import itertools | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
import distutils.util | |
def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a side-by-side plot of a generated image and its reference image.

    Runs ``G_net(imgs_lr)`` without gradient tracking, then plots element 0 of
    the generator output (left) next to element 0 of ``imgs_hr`` (right). The
    figure is written to ``results/train_out/epoch_<num_epoch>_results.png``;
    the output directory is created if it is missing.

    Args:
        num_epoch: Epoch number, used in the caption and the file name.
        G_net: Generator network, called as ``G_net(imgs_lr)``.
        imgs_lr: Batch fed to the generator.
        imgs_hr: Reference batch; only element 0 is plotted.

    NOTE(review): both batches are assumed to be (batch, C, H, W) tensors
    normalized to [-1, 1] (inferred from the ``* 0.5 + 0.5`` and the CHW->HWC
    transpose) -- TODO confirm against the data pipeline.
    """
    import os

    with torch.no_grad():
        test_images = G_net(imgs_lr)

    def to_hwc(batch):
        # Undo the [-1, 1] normalization, clip to imshow's valid float range
        # (avoids the "Clipping input data" warning) and move channels last.
        img = batch.detach().cpu().numpy()[0] * 0.5 + 0.5
        return np.transpose(np.clip(img, 0.0, 1.0), [1, 2, 0])

    os.makedirs("results/train_out", exist_ok=True)

    fig, ax = plt.subplots(1, 2)
    try:
        for axis, img in zip(ax, (to_hwc(test_images), to_hwc(imgs_hr))):
            axis.imshow(img)
            # Hide the axes only after drawing, so nothing resets them.
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)
        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')
        fig.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
    finally:
        # Always release the figures, even on error, to avoid leaking memory
        # across epochs.
        plt.close('all')
#---------------------------------------------------------#
#   Convert the image to RGB so that grayscale (or other
#   non-RGB) inputs do not raise errors at prediction time.
#   Only RGB prediction is supported; every other image type
#   is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` as an RGB image.

    Objects with a string ``mode`` attribute (e.g. PIL images) are returned
    unchanged only when the mode is exactly ``'RGB'``. Any other mode is
    converted with ``image.convert('RGB')``. This includes 3-channel modes
    such as ``'YCbCr'`` or ``'HSV'``, which a pure shape test would wrongly
    treat as RGB.

    Other objects fall back to a shape test. A 3-D array-like whose last axis
    has length 3 is returned as-is; anything else is passed to
    ``image.convert('RGB')``.
    """
    mode = getattr(image, "mode", None)
    # isinstance check: e.g. torch.Tensor has a *method* called ``mode``.
    if isinstance(mode, str):
        return image if mode == "RGB" else image.convert('RGB')
    # np.shape materializes the image as an array, so compute it only once.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
def preprocess_input(image, mean, std):
    """Normalize pixel data: scale by 1/255, subtract ``mean``, divide by ``std``.

    ``mean`` and ``std`` may be scalars or anything that broadcasts against
    ``image`` (e.g. per-channel arrays).
    """
    scaled = image / 255
    centered = scaled - mean
    return centered / std
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the first parameter group of ``optimizer``.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    The listing is framed by a header and a footer line. ``args`` is
    typically an ``argparse.Namespace``; it only needs a ``__dict__``.
    """
    print("----------- Configuration Arguments -----------")
    settings = vars(args)
    for name in sorted(settings):
        # print() applies str() to both parts, matching "%s: %s".
        print(name, settings[name], sep=": ")
    print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on ``argparser``.

    Args:
        argname: Option name without the leading ``--``.
        type: Conversion callable for the option value. ``bool`` is replaced
            by a string parser, because ``bool("False")`` would be ``True``.
            The parser accepts y/yes/t/true/on/1 and n/no/f/false/off/0,
            case-insensitively.
        default: Default value; it is shown in the help via ``%(default)s``.
        help: Help text; a default-value suffix is appended to it.
        argparser: Parser exposing ``add_argument`` (e.g. argparse.ArgumentParser).
        **kwargs: Forwarded unchanged to ``argparser.add_argument``.

    This previously used ``distutils.util.strtobool``. distutils was removed
    in Python 3.12 (PEP 632), so an equivalent parser is defined locally. It
    returns a real ``bool`` instead of 1/0; the two compare equal with ``==``.
    """
    def strtobool(val):
        # Keeps the name "strtobool" so argparse's "invalid strtobool value"
        # error message is unchanged.
        lowered = val.lower()
        if lowered in ("y", "yes", "t", "true", "on", "1"):
            return True
        if lowered in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError("invalid truth value %r" % (val,))

    type = strtobool if type is bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' 默认: %(default)s.',
                           **kwargs)