# GCycleGAN / utils/utils.py
# (Page-header residue from the web file viewer: user "Egrt", commit 95e767b "init",
#  raw / history / blame links, file size 4.96 kB.)
import itertools
import math
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
#---------------------------------------------------------#
#   Convert an image to RGB so grayscale (or other non-RGB)
#   inputs do not break prediction. Only RGB images are
#   supported; every other image type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` as an RGB image.

    For PIL-style images (objects with a ``mode`` attribute) the mode is
    checked directly. An ``'RGB'`` image is returned unchanged. Anything else
    is converted with ``image.convert('RGB')``, including 3-channel modes
    such as ``'YCbCr'``, ``'LAB'`` or ``'HSV'``.

    Objects without a ``mode`` (e.g. NumPy arrays) keep the original
    shape-based rule. They are returned unchanged when their shape is
    ``(H, W, 3)``; otherwise ``image.convert('RGB')`` is attempted.
    """
    mode = getattr(image, "mode", None)
    if mode is not None:
        # Checking the mode avoids materialising the whole image as an array
        # (np.shape falls back to np.asarray for PIL images). It also catches
        # 3-channel modes that are not RGB, which a channel-count test misses.
        if mode == "RGB":
            return image
        return image.convert("RGB")
    # Compute the shape once instead of twice.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
#---------------------------------------------------#
#   Resize the input image
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    """Resize a PIL image to ``size`` = (width, height) using bicubic resampling.

    When ``letterbox_image`` is truthy the aspect ratio is preserved. The image
    is scaled to fit inside ``size`` and pasted, centred, on a gray
    (128, 128, 128) RGB canvas of exactly ``size``. Otherwise the image is
    stretched to ``size``.

    Returns:
        ``(canvas, scaled_width, scaled_height)`` in letterbox mode, or
        ``(resized_image, None, None)`` when stretching.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    if not letterbox_image:
        return image.resize((dst_w, dst_h), Image.BICUBIC), None, None

    # Largest uniform scale that still fits both target dimensions.
    ratio = min(dst_w / src_w, dst_h / src_h)
    fit_w = int(src_w * ratio)
    fit_h = int(src_h * ratio)
    scaled = image.resize((fit_w, fit_h), Image.BICUBIC)

    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled image; leftover borders stay gray.
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas, fit_w, fit_h
#----------------------------------------#
#   Preprocess training images
#----------------------------------------#
def preprocess_input(x):
    """Map pixel values from [0, 255] to [-1, 1].

    Computes ``(x / 255 - 0.5) / 0.5`` and returns a new object. The caller's
    array or tensor is not modified, and integer inputs such as uint8 arrays
    are accepted (they are promoted to floating point by the division).
    Float inputs produce exactly the same values as before.
    """
    # Out-of-place ops: the previous in-place `/=` mutated the argument and
    # raised on integer dtypes.
    x = x / 255
    x = x - 0.5
    x = x / 0.5
    return x
def postprocess_output(x):
    """Map values from [-1, 1] back to [0, 255] (inverse of the preprocessing).

    Computes ``(x * 0.5 + 0.5) * 255`` and returns a new object, leaving the
    argument untouched. This matters because callers may pass NumPy views
    that share memory with a tensor (e.g. ``tensor.cpu().numpy()`` on a CPU
    tensor). Integer inputs are promoted to floating point instead of
    raising.
    """
    x = x * 0.5
    x = x + 0.5
    x = x * 255
    return x
def show_result(num_epoch, G_model_A2B_train, G_model_B2A_train, images_A, images_B):
    """Save a 2x2 grid of real and translated images for one epoch.

    Panel layout (row-major):
        real A | G_A2B(A)
        real B | G_B2A(B)

    Only the first image of each batch is shown. The figure is written to
    ``results/train_out/epoch_<num_epoch>_results.png``; the directory is
    created if it does not exist.

    The input batches are never modified. Previously, in-place
    post-processing of ``images_X.cpu().numpy()`` wrote through to CPU
    tensors, because ``.cpu()`` returns the same tensor when it is already on
    the CPU and ``.numpy()`` shares its memory.
    """
    import os

    # NOTE(review): if these generators are in train() mode, this forward
    # pass may update normalization running statistics -- confirm intended.
    with torch.no_grad():
        fake_image_B = G_model_A2B_train(images_A)
        fake_image_A = G_model_B2A_train(images_B)

    def to_display(batch):
        # Take the first image, move channels last for imshow, and map
        # [-1, 1] -> [0, 1]. The arithmetic is out-of-place, so shared
        # memory is never written. Clipping guards against values slightly
        # outside the range.
        image = batch[0].detach().cpu().numpy()
        return np.clip(np.transpose(image, (1, 2, 0)) * 0.5 + 0.5, 0, 1)

    panels = (images_A, fake_image_B, images_B, fake_image_A)
    fig, axes = plt.subplots(2, 2)
    for ax, batch in zip(axes.flatten(), panels):
        ax.imshow(to_display(batch))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    save_path = "results/train_out/epoch_" + str(num_epoch) + "_results.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path)
    plt.close('all')  # close every figure so repeated calls do not leak memory
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column table.

    Keys are right-aligned in a 25-character column and values (via ``str``)
    in a 40-character column. Longer entries are not truncated. The table is
    framed by 70-character dashed rules.
    """
    rule = '-' * 70
    row = "|{:>25} | {:>40}|".format

    print('Configurations:')
    print(rule)
    print(row('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row(str(name), str(setting)))
    print(rule)
#---------------------------------------------------#
#   Get the learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a learning-rate schedule as a function of the iteration index.

    Args:
        lr_decay_type: ``"cos"`` selects quadratic warmup followed by cosine
            annealing and a flat ``min_lr`` tail. Any other value selects a
            staircase (step) exponential decay.
        lr: peak / initial learning rate.
        min_lr: final learning rate.
        total_iters: length of the schedule, in the same unit the returned
            function is called with.
        warmup_iters_ratio: fraction of ``total_iters`` used for warmup,
            clamped to [1, 3] iterations ("cos" only).
        warmup_lr_ratio: warmup starts at ``lr * warmup_lr_ratio``, but at
            least 1e-6 ("cos" only).
        no_aug_iter_ratio: fraction of ``total_iters`` held at ``min_lr`` at
            the end, clamped to [1, 15] iterations ("cos" only).
        step_num: number of plateaus in step mode. The decay factor is chosen
            so the last plateau reaches ``min_lr``.

    Returns:
        A ``functools.partial`` that takes the current iteration and returns
        the learning rate. In step mode, calling it raises ``ValueError`` if
        the resulting step size is below 1.
    """
    def warm_cos(peak, floor, n_total, n_warmup, warmup_start, n_tail, it):
        if it <= n_warmup:
            # Quadratic ramp from warmup_start up to peak.
            return (peak - warmup_start) * pow(it / float(n_warmup), 2) + warmup_start
        if it >= n_total - n_tail:
            # Hold at the floor for the final n_tail iterations.
            return floor
        # Half-cosine from peak down to floor over the middle segment.
        angle = math.pi * (it - n_warmup) / (n_total - n_warmup - n_tail)
        return floor + 0.5 * (peak - floor) * (1.0 + math.cos(angle))

    def staircase(start, factor, plateau, it):
        if plateau < 1:
            raise ValueError("step_size must above 1.")
        return start * factor ** (it // plateau)

    if lr_decay_type == "cos":
        n_warmup = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_start = max(warmup_lr_ratio * lr, 1e-6)
        n_tail = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        return partial(warm_cos, lr, min_lr, total_iters, n_warmup, warmup_start, n_tail)

    # Per-plateau factor such that lr * factor ** (step_num - 1) == min_lr.
    factor = (min_lr / lr) ** (1 / (step_num - 1))
    plateau = total_iters / step_num
    return partial(staircase, lr, factor, plateau)
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Apply the scheduled learning rate for ``epoch`` to every parameter group.

    ``lr_scheduler_func`` is called exactly once, and its result is written
    to the ``'lr'`` key of each entry in ``optimizer.param_groups``.
    """
    new_lr = lr_scheduler_func(epoch)
    groups = optimizer.param_groups
    for group in groups:
        group['lr'] = new_lr