import itertools
import math
import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


#---------------------------------------------------------#
#   Convert an image to RGB so that grayscale inputs do not
#   break prediction. Only RGB prediction is supported, so
#   every other image type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` unchanged if it is already 3-channel, else as RGB.

    Args:
        image: An image whose ``np.shape`` is inspected. When it is not
            3-dimensional with 3 channels on the last axis it must provide
            a ``convert('RGB')`` method (e.g. a PIL image).

    Returns:
        The same object when it already has shape ``(H, W, 3)``; otherwise
        the result of ``image.convert('RGB')``.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')


#---------------------------------------------------#
#   Resize the input image
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    """Resize a PIL image to ``size``, optionally preserving aspect ratio.

    Args:
        image: PIL image to resize.
        size: Target ``(width, height)``.
        letterbox_image: If truthy, scale the image to fit inside ``size``
            without distortion and centre it on a grey (128, 128, 128)
            canvas. Otherwise stretch it to exactly ``size``.

    Returns:
        ``(new_image, nw, nh)`` where ``nw``/``nh`` are the scaled content
        width/height in letterbox mode, or ``None`` in stretch mode.
    """
    iw, ih = image.size
    w, h = size

    if not letterbox_image:
        return image.resize((w, h), Image.BICUBIC), None, None

    # Use the smaller ratio so the whole image fits inside the target box.
    scale = min(w / iw, h / ih)
    nw = int(iw * scale)
    nh = int(ih * scale)

    scaled = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled image on the padded canvas.
    new_image.paste(scaled, ((w - nw) // 2, (h - nh) // 2))
    return new_image, nw, nh


#----------------------------------------#
#   Preprocess training images
#----------------------------------------#
def preprocess_input(x):
    """Map pixel values from ``[0, 255]`` to ``[-1, 1]``.

    Uses augmented assignment, so an object that supports in-place
    arithmetic (such as a float NumPy array) is modified in place and the
    same object is returned.
    """
    x /= 255
    x -= 0.5
    x /= 0.5
    return x


def postprocess_output(x):
    """Map values from ``[-1, 1]`` back to ``[0, 255]``.

    Inverse of ``preprocess_input``. Uses augmented assignment, so an
    object that supports in-place arithmetic is modified in place; pass a
    copy when the original must be preserved.
    """
    x *= 0.5
    x += 0.5
    x *= 255
    return x


def _first_image_chw(batch):
    """Return the first image of a tensor batch as a NumPy array."""
    return batch.detach().cpu().numpy()[0]


def _real_to_display(batch):
    """Convert the first real image of ``batch`` to a uint8 HWC array."""
    # Copy first: the NumPy view of a CPU tensor shares its memory and
    # postprocess_output works in place, which would corrupt the batch.
    chw = postprocess_output(_first_image_chw(batch).copy())
    return np.transpose(np.uint8(chw), [1, 2, 0])


def _fake_to_display(batch):
    """Convert the first generated image of ``batch`` to a [0, 1] HWC array."""
    chw = np.clip(_first_image_chw(batch) * 0.5 + 0.5, 0, 1)
    return np.transpose(chw, [1, 2, 0])


def show_result(num_epoch, G_model_A2B_train, G_model_B2A_train, images_A, images_B):
    """Save a 2x2 grid of real and translated images for one epoch.

    The grid shows, in order: real A, generated B (from A), real B and
    generated A (from B). Only the first image of each batch is drawn.
    The input batches are left unmodified.

    Args:
        num_epoch: Epoch number, used in the caption and the file name.
        G_model_A2B_train: Callable mapping a batch of A images to B.
        G_model_B2A_train: Callable mapping a batch of B images to A.
        images_A: Tensor batch from domain A, indexed as ``[batch, C, H, W]``
            by the plotting code.
        images_B: Tensor batch from domain B, same layout.

    Side effects:
        Writes ``results/train_out/epoch_<num_epoch>_results.png`` (creating
        the directory if needed) and closes all matplotlib figures.
    """
    with torch.no_grad():
        fake_image_B = G_model_A2B_train(images_A)
        fake_image_A = G_model_B2A_train(images_B)

    fig, ax = plt.subplots(2, 2)
    ax = ax.flatten()
    for axis in ax:
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    panels = (
        _real_to_display(images_A),
        _fake_to_display(fake_image_B),
        _real_to_display(images_B),
        _fake_to_display(fake_image_A),
    )
    for axis, panel in zip(ax, panels):
        axis.cla()
        axis.imshow(panel)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    # Make sure the output directory exists before saving.
    os.makedirs("results/train_out", exist_ok=True)
    fig.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
    plt.close('all')  # avoid leaking figure memory


def show_config(**kwargs):
    """Print the given keyword arguments as a two-column table."""
    print('Configurations:')
    print('-' * 70)
    print('|%25s | %40s|' % ('keys', 'values'))
    print('-' * 70)
    for key, value in kwargs.items():
        print('|%25s | %40s|' % (str(key), str(value)))
    print('-' * 70)


#---------------------------------------------------#
#   Get the learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']


def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a function mapping an iteration index to a learning rate.

    Args:
        lr_decay_type: ``"cos"`` selects warm-up followed by cosine decay;
            any other value selects step decay.
        lr: Initial (maximum) learning rate.
        min_lr: Final (minimum) learning rate.
        total_iters: Total number of iterations the schedule spans.
        warmup_iters_ratio: Fraction of ``total_iters`` used for warm-up
            (cosine only; the resulting count is clamped to ``[1, 3]``).
        warmup_lr_ratio: Warm-up starts at ``warmup_lr_ratio * lr``, but
            not below ``1e-6`` (cosine only).
        no_aug_iter_ratio: Fraction of ``total_iters`` held at ``min_lr``
            at the end (cosine only; the count is clamped to ``[1, 15]``).
        step_num: Number of steps for step decay. ``step_num == 1`` yields
            a decay rate of 1.0 (no decay between steps).

    Returns:
        A callable ``func(iters) -> lr``.

    Raises:
        ValueError: From the returned step-decay function when
            ``total_iters / step_num`` is below 1.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic (not linear) ramp from warmup_lr_start up to lr.
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            # Final iterations are held at the minimum learning rate.
            lr = min_lr
        else:
            # Half-cosine decay from lr down to min_lr in between.
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        if step_num == 1:
            # With a single step there is nothing to decay between, and the
            # general formula would divide by zero.
            decay_rate = 1.0
        else:
            # Chosen so that lr reaches min_lr on the last of step_num steps.
            decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func


def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set every parameter group's ``'lr'`` to ``lr_scheduler_func(epoch)``."""
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr