|
import itertools |
|
import math |
|
from functools import partial |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
def cvtColor(image):
    """Return *image* as a three-channel RGB image.

    If ``np.shape(image)`` is three-dimensional with a last axis of length 3,
    the input is handed back untouched.  Anything else is converted through
    the object's own ``convert('RGB')`` method (the PIL ``Image`` API).
    """
    dims = np.shape(image)
    already_rgb = len(dims) == 3 and dims[2] == 3
    if already_rgb:
        return image
    return image.convert('RGB')
|
|
|
|
|
|
|
|
|
def resize_image(image, size, letterbox_image):
    """Resize *image* to *size*, optionally keeping its aspect ratio.

    Args:
        image: Image object exposing ``size`` and ``resize`` (PIL ``Image`` API).
        size: Target ``(width, height)``.
        letterbox_image: When truthy, scale the image to fit inside *size*
            without distortion and centre it on a grey (128, 128, 128) canvas.
            Otherwise stretch the image to exactly *size*.

    Returns:
        ``(new_image, nw, nh)`` where ``nw``/``nh`` are the scaled content
        dimensions in letterbox mode, and ``None`` for a plain resize.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # Plain stretch: no padding, so there is no inner content size to report.
    if not letterbox_image:
        stretched = image.resize((dst_w, dst_h), Image.BICUBIC)
        return stretched, None, None

    # Use the smaller ratio so the scaled image fits in both dimensions.
    ratio = min(dst_w/src_w, dst_h/src_h)
    nw, nh = int(src_w*ratio), int(src_h*ratio)

    scaled = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    offset = ((dst_w - nw)//2, (dst_h - nh)//2)
    canvas.paste(scaled, offset)
    return canvas, nw, nh
|
|
|
|
|
|
|
|
|
def preprocess_input(x):
    """Rescale pixel values from ``[0, 255]`` to ``[-1, 1]``.

    The arithmetic uses augmented assignment, so a mutable array argument is
    modified in place; the (same) object is also returned for convenience.
    """
    pixel_max, centre, half_range = 255, 0.5, 0.5
    x /= pixel_max      # [0, 255] -> [0, 1]
    x -= centre         # [0, 1]   -> [-0.5, 0.5]
    x /= half_range     # [-0.5, 0.5] -> [-1, 1]
    return x
|
|
|
def postprocess_output(x):
    """Rescale values from ``[-1, 1]`` back to ``[0, 255]``.

    Inverse of ``preprocess_input``.  Uses augmented assignment, so a mutable
    array argument is modified in place; the same object is returned.
    """
    half_range, centre, pixel_max = 0.5, 0.5, 255
    x *= half_range     # [-1, 1]     -> [-0.5, 0.5]
    x += centre         # [-0.5, 0.5] -> [0, 1]
    x *= pixel_max      # [0, 1]      -> [0, 255]
    return x
|
|
|
def show_result(num_epoch, G_model_A2B_train, G_model_B2A_train, images_A, images_B):
    """Save a 2x2 preview of real and generated images for one epoch.

    Layout (row-major): real A, fake B (A->B), real B, fake A (B->A).  Only
    the first sample of each batch is drawn.  The figure is written to
    ``results/train_out/epoch_<num_epoch>_results.png``.

    Args:
        num_epoch: Value used in the figure caption and the file name.
        G_model_A2B_train: Callable generator mapping domain A to domain B.
        G_model_B2A_train: Callable generator mapping domain B to domain A.
        images_A: Batch tensor from domain A.  Indexing and the transpose
            below imply an (N, C, H, W) layout; values are presumably
            normalised to [-1, 1] -- confirm against the data loader.
        images_B: Batch tensor from domain B, same assumptions.
    """
    import os  # local import: only needed to make sure the output folder exists

    with torch.no_grad():
        fake_image_B = G_model_A2B_train(images_A)
        fake_image_A = G_model_B2A_train(images_B)

    def _real_panel(batch):
        # ``postprocess_output`` works in place, and for a tensor already on
        # the CPU ``.cpu().numpy()`` shares memory with the tensor itself.
        # Copy first so the caller's batch is never modified by plotting.
        first = batch.cpu().numpy()[0].copy()
        return np.transpose(np.uint8(postprocess_output(first)), [1, 2, 0])

    def _fake_panel(batch):
        # The arithmetic allocates a new array, so no copy is required here.
        first = batch.cpu().numpy()[0]
        return np.transpose(np.clip(first * 0.5 + 0.5, 0, 1), [1, 2, 0])

    panels = (
        _real_panel(images_A),
        _fake_panel(fake_image_B),
        _real_panel(images_B),
        _fake_panel(fake_image_A),
    )

    fig, ax = plt.subplots(2, 2)
    for axis, panel in zip(ax.flatten(), panels):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.cla()
        axis.imshow(panel)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    # savefig does not create missing directories on its own.
    os.makedirs("results/train_out", exist_ok=True)
    plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
    plt.close('all')
|
|
|
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column text table.

    Keys are right-aligned in a 25-character column and values in a
    40-character column; both are rendered with ``str()``.
    """
    rule = '-' * 70
    row = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row % ('keys', 'values'))
    print(rule)
    for name in kwargs:
        print(row % (str(name), str(kwargs[name])))
    print(rule)
|
|
|
|
|
|
|
|
|
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
|
|
|
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a function mapping an iteration index to a learning rate.

    Args:
        lr_decay_type: ``"cos"`` selects warm-up + cosine annealing; any other
            value selects the step (geometric) schedule.
        lr: Peak / initial learning rate.
        min_lr: Final learning rate.
        total_iters: Total number of iterations the schedule spans.
        warmup_iters_ratio: Fraction of ``total_iters`` used for warm-up
            (the resulting length is clamped to the range [1, 3]).
        warmup_lr_ratio: Warm-up starts at ``warmup_lr_ratio * lr``
            (but never below 1e-6).
        no_aug_iter_ratio: Fraction of ``total_iters`` held at ``min_lr`` at
            the end (the resulting length is clamped to the range [1, 15]).
        step_num: Number of distinct learning-rate levels in the step
            schedule.  ``step_num == 1`` yields a constant rate of ``lr``.

    Returns:
        A callable ``f(iters) -> float``.

    Raises:
        ValueError: Raised by the returned step schedule when called, if
            ``total_iters / step_num`` is below 1.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic ramp from warmup_lr_start up to lr.
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            # Tail of the schedule: hold the floor value.
            lr = min_lr
        else:
            # Half-cosine decay from lr to min_lr between the two phases above.
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        if step_num == 1:
            # A single level means no decay; the general formula below would
            # divide by zero (step_num - 1 == 0).
            decay_rate = 1.0
        else:
            # Chosen so that after step_num - 1 decays lr has reached min_lr.
            decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func
|
|
|
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Apply the scheduled learning rate for *epoch* to every parameter group.

    ``lr_scheduler_func(epoch)`` is evaluated exactly once and the result is
    stored under the ``'lr'`` key of each entry in ``optimizer.param_groups``.
    """
    scheduled_lr = lr_scheduler_func(epoch)
    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr
|
|