|
import torch |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
def adjust_learning_rate(optimiser, iters, learning_rate_decay, LR):
    """Overwrite every parameter group's learning rate with an inverse-time decay.

    Each group's ``'lr'`` entry is set to
    ``LR / (1.0 + learning_rate_decay * iters)``; the previous value of the
    entry is not consulted. Nothing is returned.

    Args:
        optimiser: object exposing ``param_groups``, an iterable of mutable
            mappings (e.g. a ``torch.optim`` optimizer).
        iters: current iteration count.
        learning_rate_decay: decay coefficient multiplied by ``iters``.
        LR: base (undecayed) learning rate.
    """
    for group in optimiser.param_groups:
        denominator = 1.0 + learning_rate_decay * iters
        group['lr'] = LR / denominator
|
|
|
def concat_img(imgs, batch):
    """Save a batch of images side by side, left to right, as one PNG.

    The tensor is moved channel-last, the images are concatenated along the
    width axis, and the resulting strip is written with axes hidden to
    ``../../produced-images/batch{batch}img.png`` (relative to the current
    working directory).

    NOTE(review): ``concat_img`` is defined twice in this module; the later
    definition rebinds the name at import time, so this body is currently
    dead code. Consider keeping only one copy.

    Args:
        imgs: 4-D tensor, presumably (batch, channels, height, width) — TODO
            confirm against callers. It is detached and copied to the CPU.
        batch: value substituted into the output file name.
    """
    # Reorder dims (0, 1, 2, 3) -> (0, 2, 3, 1), i.e. channels last, which is
    # the layout matplotlib's imshow expects for colour images.
    arr = imgs.permute(0, 2, 3, 1).detach().cpu().numpy()
    # np.concatenate iterates over axis 0 of the array directly, so there is
    # no need for the slow round-trip through nested Python lists.
    strip = np.concatenate(arr, axis=1)

    fig = plt.figure()
    try:
        plt.imshow(strip)
        plt.axis('off')
        fig.savefig("../../produced-images/batch{}img.png".format(batch))
    finally:
        # Release the figure even if saving fails; pyplot otherwise keeps it alive.
        plt.close(fig)
|
|
|
def concat_img(imgs, batch):
    """Save a batch of images side by side, left to right, as one PNG.

    The tensor is moved channel-last, the images are concatenated along the
    width axis, and the resulting strip is written with axes hidden to
    ``../../produced-images/batch{batch}img.png`` (relative to the current
    working directory). The pyplot figure is always closed afterwards so
    repeated calls do not accumulate open figures.

    NOTE(review): this redefinition replaces an earlier ``concat_img`` in this
    module; consider keeping only one copy.

    Args:
        imgs: 4-D tensor, presumably (batch, channels, height, width) — TODO
            confirm against callers. It is detached and copied to the CPU.
        batch: value substituted into the output file name.
    """
    # Reorder dims (0, 1, 2, 3) -> (0, 2, 3, 1), i.e. channels last, which is
    # the layout matplotlib's imshow expects for colour images.
    arr = imgs.permute(0, 2, 3, 1).detach().cpu().numpy()
    # np.concatenate iterates over axis 0 of the array directly, so there is
    # no need for the slow round-trip through nested Python lists.
    strip = np.concatenate(arr, axis=1)

    fig = plt.figure()
    try:
        plt.imshow(strip)
        plt.axis('off')
        fig.savefig("../../produced-images/batch{}img.png".format(batch))
    finally:
        # Previously the figure was never closed, leaking one figure per call.
        plt.close(fig)
|
|
|
|
|
def mean_and_std_of_image(x, *, eps=0.0):
    """Compute the per-sample, per-channel mean and standard deviation.

    Every dimension after the first two is flattened and reduced over, so
    for a 4-D input the statistics are taken over each channel's spatial
    positions.

    Args:
        x: tensor with at least two dimensions. Dims 0 and 1 are kept
            (presumably batch and channels — confirm against callers).
            Non-contiguous tensors (e.g. after ``permute``) are accepted.
        eps: keyword-only value added to the variance before the square
            root. The default ``0.0`` reproduces the original behaviour. A
            small positive value (e.g. ``1e-5``) keeps ``std`` and its
            gradient finite when a channel is constant.

    Returns:
        ``(mean, std)``, each shaped ``(x.shape[0], x.shape[1], 1, 1)``.
        ``std`` uses torch's default unbiased (N - 1) variance, so it is NaN
        when each channel holds a single element.
    """
    n, c = x.shape[0], x.shape[1]
    # reshape (not view): view raises on non-contiguous inputs.
    flat = x.reshape(n, c, -1)

    mean = flat.mean(dim=2)
    # sqrt has an infinite derivative at 0; eps > 0 avoids NaN gradients.
    std = (flat.var(dim=2) + eps).sqrt()

    return mean.view(n, c, 1, 1), std.view(n, c, 1, 1)