import itertools
import os

import numpy as np
import matplotlib.pyplot as plt
import torch

try:
    # distutils was removed from the standard library in Python 3.12
    # (PEP 632). The import is kept only so the name stays available on
    # interpreters that still ship it; nothing below depends on it.
    import distutils.util
except ImportError:
    pass


def _strtobool(value):
    """Convert a truth-value string to 1 or 0.

    Local replacement for ``distutils.util.strtobool`` with the same
    accepted spellings, the same integer return values and the same
    ``ValueError`` on anything else.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError("invalid truth value %r" % (normalized,))


def _to_display_image(batch):
    """Return the first item of ``batch`` as an array ready for ``imshow``.

    The item is scaled by ``* 0.5 + 0.5`` and its axes are reordered with
    ``[1, 2, 0]``.
    NOTE(review): presumably maps a CHW tensor normalised to [-1, 1] onto an
    HWC image in [0, 1] -- confirm against the data pipeline.
    """
    first = batch.cpu().numpy()[0]
    return np.transpose(first * 0.5 + 0.5, [1, 2, 0])


def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a side-by-side figure of the generator output and ``imgs_hr``.

    Args:
        num_epoch: Epoch number; used in the figure caption and file name.
        G_net: Callable applied to ``imgs_lr`` (under ``torch.no_grad()``).
        imgs_lr: Input batch passed to ``G_net``.
        imgs_hr: Batch whose first item is shown in the right-hand panel.

    The figure is written to
    ``results/train_out/epoch_<num_epoch>_results.png``; the directory is
    created if it does not exist yet.
    """
    with torch.no_grad():
        test_images = G_net(imgs_lr)

    fig, axes = plt.subplots(1, 2)
    for axis in axes:
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    # Left panel: generated image; right panel: reference image.
    for axis, batch in zip(axes, (test_images, imgs_hr)):
        axis.cla()
        axis.imshow(_to_display_image(batch))

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    # savefig does not create missing directories, so make sure it exists.
    out_dir = "results/train_out"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_dir + "/epoch_" + str(num_epoch) + "_results.png")

    plt.close('all')  # avoid leaking figures (and memory) across epochs


def cvtColor(image):
    """Return ``image`` as a 3-channel image.

    An input whose shape has three dimensions with a last dimension of 3 is
    returned unchanged. Anything else (e.g. a grayscale image) is converted
    with ``image.convert('RGB')``, so that downstream prediction code, which
    only supports RGB input, does not fail.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')


def preprocess_input(image, mean, std):
    """Scale ``image`` by 1/255, then subtract ``mean`` and divide by ``std``."""
    return (image / 255 - mean) / std


def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']
    return None


def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name."""
    print("----------- Configuration Arguments -----------")
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register ``--<argname>`` on ``argparser``.

    Args:
        argname: Option name without the leading ``--``.
        type: Converter for the option value. ``bool`` is replaced by a
            string-to-truth-value parser, because ``bool("false")`` is
            ``True``; the parsed value is the integer 1 or 0.
        default: Default value for the option.
        help: Help text; the default value is appended to it.
        argparser: Parser (or argument group) exposing ``add_argument``.
        **kwargs: Forwarded unchanged to ``add_argument``.
    """
    converter = _strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=converter,
        help=help + ' 默认: %(default)s.',
        **kwargs
    )