endo-yuki-t
initial commit
d7dbcdd
raw
history blame
No virus
8.17 kB
import os
import math, random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils import common
from criteria.lpips.lpips import LPIPS
from models.StyleGANControler import StyleGANControler
from training.ranger import Ranger
from expansion.submission import Expansion
from expansion.utils.flowlib import point_vec
class Coach:
    """Trainer for the encoder of ``StyleGANControler``.

    Each training step samples a source latent ``w1`` and a second latent
    ``w2``, moves ``w1`` a fraction ``sigma`` towards ``w2`` on
    ``target_layers`` (giving ``w_mid``), estimates flow / log-expansion
    between the two rendered images with ``Expansion``, and trains
    ``net.encoder`` to predict ``w_mid[:, target_layers]`` from
    ``w1[:, target_layers]``, the flow features and the generator feature map.
    """

    # Substring of ``opts.stylegan_weights`` -> (sigma_f, sigma_e): divisors
    # applied to the flow and log-expansion maps in ``train``. Checked in
    # order; the first matching substring wins.
    _FLOW_NORMALIZATION = (
        ('ffhq', 4, 0.02),
        ('car', 5, 0.03),
        ('cat', 12, 0.04),
        ('church', 8, 0.02),
        ('anime', 7, 0.025),
    )

    def __init__(self, opts):
        """Build the network, losses, optimizer, logger and flow estimator.

        Args:
            opts: experiment options namespace. Fields read here include
                ``checkpoint_path``, ``stylegan_weights``, ``lpips_lambda``,
                ``exp_dir``, ``save_interval`` and ``max_steps``.
                ``opts.device`` is written, and ``opts.save_interval`` is
                set to ``opts.max_steps`` when it is None.

        Raises:
            ValueError: if ``opts.stylegan_weights`` contains none of the
                dataset names in ``_FLOW_NORMALIZATION``.
        """
        self.opts = opts
        # Resolve the flow normalization first, so an unsupported generator
        # fails fast (before the network is built) instead of surfacing as an
        # AttributeError in the middle of ``train``.
        self.sigma_f, self.sigma_e = self._flow_normalization(self.opts.stylegan_weights)

        if self.opts.checkpoint_path is None:
            self.global_step = 0
        else:
            # Expects checkpoints named '<prefix>_<step>.pt', e.g.
            # 'iteration_<step>.pt' as written by checkpoint_me.
            stem = os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0]
            self.global_step = int(stem.split('_')[-1])

        self.device = 'cuda:0'  # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
        self.opts.device = self.device

        # Initialize network
        self.net = StyleGANControler(self.opts).to(self.device)

        # Initialize losses (mse_loss is not used by calc_loss; kept for callers)
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        self.mse_loss = nn.MSELoss().to(self.device).eval()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        # Initialize optical flow estimator
        self.ex = Expansion()

    @classmethod
    def _flow_normalization(cls, stylegan_weights):
        """Return ``(sigma_f, sigma_e)`` for the generator at ``stylegan_weights``.

        Raises:
            ValueError: if the path contains none of the known dataset names
                (a None path is treated as matching nothing).
        """
        path = stylegan_weights or ''
        for key, sigma_f, sigma_e in cls._FLOW_NORMALIZATION:
            if key in path:
                return sigma_f, sigma_e
        known = ', '.join(key for key, _, _ in cls._FLOW_NORMALIZATION)
        raise ValueError(
            'Cannot infer flow normalization from stylegan_weights={!r}; '
            'expected the path to contain one of: {}'.format(stylegan_weights, known))

    @staticmethod
    def _every(step, interval):
        """Return True when ``interval`` is a positive int that divides ``step``.

        A falsy or non-positive interval means "never" instead of raising
        ZeroDivisionError / TypeError.
        """
        return bool(interval) and interval > 0 and step % interval == 0

    def _should_checkpoint(self):
        """Return True when the current step must be saved.

        A step is saved when it is a multiple of ``opts.save_interval``, and
        the final step (``opts.max_steps``) is always saved.
        """
        return (self._every(self.global_step, self.opts.save_interval)
                or self.global_step == self.opts.max_steps)

    def _sampling_grid(self):
        """Return the fixed sparse sampling grid, repeated over the batch.

        The grid holds 16x16 points in ``grid_sample``'s normalized
        coordinates, at ``p / 127.5 - 1`` for ``p`` in 0, 16, ..., 240.
        Result shape: ``(batch_size, 16, 16, 2)``, float32, on ``self.device``.
        """
        coords = np.arange(0, 256, 16, dtype=np.float32) / 127.5 - 1.
        xx, yy = np.meshgrid(coords, coords)
        grid = np.stack([xx, yy], axis=2)  # (16, 16, 2); last axis is (x, y)
        grid = torch.from_numpy(grid[None]).to(self.device)
        return grid.repeat(self.opts.batch_size, 1, 1, 1)

    def train(self, truncation=0.3, sigma=0.1, target_layers=None):
        """Train the encoder from ``self.global_step`` up to ``opts.max_steps``.

        Steps ``global_step .. max_steps`` (inclusive) are executed. The
        final step is always checkpointed.

        Args:
            truncation: truncation value passed to the decoder when sampling
                the source latent ``w1``.
            sigma: interpolation weight moving ``w1`` towards ``w2`` on
                ``target_layers``.
            target_layers: indices along the latent's layer axis (dim 1) to
                edit. Defaults to ``[0, 1, 2, 3, 4, 5]``.
        """
        if target_layers is None:
            target_layers = [0, 1, 2, 3, 4, 5]
        grid = self._sampling_grid()

        # '<=' rather than '<': the iteration at max_steps must run, because
        # it writes the final checkpoint and exits the loop below. With '<'
        # that branch was unreachable.
        while self.global_step <= self.opts.max_steps:
            # --- Synthesize a training pair (no gradients needed) ---
            with torch.no_grad():
                # Sample on the CPU generator and then move, as before, so the
                # random stream is unchanged.
                z1 = torch.randn(self.opts.batch_size, 512).to(self.device)
                z2 = torch.randn(self.opts.batch_size, self.net.style_num, 512).to(self.device)
                x1, w1, f1 = self.net.decoder(
                    [z1], input_is_latent=False, randomize_noise=False,
                    return_feature_map=True, return_latents=True,
                    truncation=truncation, truncation_latent=self.net.latent_avg[0])
                x1 = self.net.face_pool(x1)
                # Only the latents of the second sample are used.
                _, w2 = self.net.decoder(
                    [z2], input_is_latent=False, randomize_noise=False,
                    return_latents=True, truncation_latent=self.net.latent_avg[0])

                # Move w1 a fraction `sigma` towards w2 on the edited layers only.
                w_mid = w1.clone()
                w_mid[:, target_layers] = w_mid[:, target_layers] + sigma * (
                    w2[:, target_layers] - w_mid[:, target_layers])
                x_mid, _ = self.net.decoder(
                    [w_mid], input_is_latent=True, randomize_noise=False, return_latents=False)
                x_mid = self.net.face_pool(x_mid)

                # Flow / log-expansion between source and edited image,
                # scaled by the per-dataset normalization constants.
                flow, logexp = self.ex.run(x1.detach(), x_mid.detach())
                flow_feature = torch.cat([flow / self.sigma_f, logexp / self.sigma_e], dim=1)

                # Sample generator features and flow features on the same sparse
                # grid, then flatten to (batch, points, channels).
                f1 = F.interpolate(f1, flow_feature.shape[2:])
                f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True)
                flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True)
                flow_feature = flow_feature.view(
                    flow_feature.shape[0], flow_feature.shape[1], -1).permute(0, 2, 1)
                f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0, 2, 1)

            # --- Optimize the encoder to predict the edited latent ---
            self.net.train()
            self.optimizer.zero_grad()
            w_hat = self.net.encoder(w1[:, target_layers].detach(), flow_feature.detach(), f1.detach())
            loss, loss_dict, _ = self.calc_loss(w_hat, w_mid[:, target_layers].detach())
            loss.backward()
            self.optimizer.step()

            if self._every(self.global_step, self.opts.image_interval) or (
                    self.global_step < 1000 and self.global_step % 100 == 0):
                self._log_training_images(x1, flow, x_mid, w_mid, w_hat, target_layers)
                print(loss_dict)

            if self._should_checkpoint():
                self.checkpoint_me(loss_dict, is_best=False)
            if self.global_step == self.opts.max_steps:
                print('OMG, finished training!')
                break
            self.global_step += 1

    def _log_training_images(self, x1, flow, x_mid, w_mid, w_hat, target_layers):
        """Log the flow visualization, the edited target and the prediction.

        The prediction image (``w_mid`` with ``target_layers`` replaced by
        ``w_hat``) is only needed for logging, so it is rendered here, without
        gradients, and only on logging steps. ``w_mid`` is not modified.
        """
        with torch.no_grad():
            w_pred = w_mid.clone()
            w_pred[:, target_layers] = w_hat.detach()
            x_hat, _ = self.net.decoder([w_pred], input_is_latent=True, randomize_noise=False)
            x_hat = self.net.face_pool(x_hat)

        # Flow visualization drawn over the first source image of the batch.
        img_src = ((x1.detach() + 1.) * 127.5)[0].permute(1, 2, 0).cpu().numpy()
        flow_np = torch.cat((flow, torch.ones_like(flow)[:, :1]), dim=1)[0].permute(1, 2, 0).cpu().numpy()
        flowvis = point_vec(img_src, flow_np)
        # Reverse the channel axis (presumably BGR -> RGB) and map 0..255 to
        # [-1, 1] (assumes point_vec returns 0..255 pixel values).
        flowvis = torch.from_numpy(flowvis[:, :, ::-1].copy()).permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.
        self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images')

    def checkpoint_me(self, loss_dict, is_best):
        """Save a checkpoint and append a summary line to ``timestamp.txt``.

        Writes ``best_model.pt`` when ``is_best``, else
        ``iteration_<global_step>.pt``, into ``self.checkpoint_dir``.
        ``self.best_val_loss`` must be a number when ``is_best`` is True.
        """
        save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
            if is_best:
                f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(
                    self.global_step, self.best_val_loss, loss_dict))
            else:
                f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))

    def configure_optimizers(self):
        """Create the optimizer over the encoder (plus decoder if ``opts.train_decoder``).

        Uses Adam when ``opts.optim_name == 'adam'``, otherwise Ranger, both
        with ``opts.learning_rate``.
        """
        params = list(self.net.encoder.parameters())
        if self.opts.train_decoder:
            params += list(self.net.decoder.parameters())
        if self.opts.optim_name == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
        else:
            optimizer = Ranger(params, lr=self.opts.learning_rate)
        return optimizer

    def calc_loss(self, latent, w, y_hat=None, y=None):
        """Return ``(loss, loss_dict, id_logs)`` for a predicted latent.

        Terms, each enabled by a positive weight in ``opts``:
            * ``l2_lambda``: MSE between images ``y_hat`` and ``y``
              (only when both are given).
            * ``lpips_lambda``: LPIPS between ``y_hat`` and ``y``
              (only when both are given).
            * ``l2latent_lambda``: MSE between ``latent`` and ``w``.

        ``loss`` is the weighted sum. It stays the float ``0.0`` if no term
        is enabled, so ``loss.backward()`` would then fail. ``loss_dict``
        maps each term name and ``'loss'`` to Python floats. ``id_logs`` is
        always None.
        """
        loss_dict = {}
        loss = 0.0
        id_logs = None
        if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None):
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda
        if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None):
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda
        if self.opts.l2latent_lambda > 0:
            loss_l2latent = F.mse_loss(latent, w)
            loss_dict['loss_l2latent'] = float(loss_l2latent)
            loss += loss_l2latent * self.opts.l2latent_lambda
        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1):
        """Convert the first ``display_count`` samples to images and log them.

        ``x``, ``y`` and ``y_hat`` are converted with ``common.tensor2im`` into
        the 'input_face', 'target_face' and 'output_face' entries. When
        ``id_logs`` is given, its per-sample dict entries are merged in.
        """
        im_data = []
        for i in range(display_count):
            cur_im_data = {
                'input_face': common.tensor2im(x[i]),
                'target_face': common.tensor2im(y[i]),
                'output_face': common.tensor2im(y_hat[i]),
            }
            if id_logs is not None:
                for key in id_logs[i]:
                    cur_im_data[key] = id_logs[i][key]
            im_data.append(cur_im_data)
        self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        """Render ``im_data`` with ``common.vis_faces`` and save it as a JPEG.

        The file goes to ``<logger.log_dir>/<name>/[<subscript>_]<step:04d>.jpg``,
        where step is ``global_step``, or 0 when ``log_latest`` is set.
        """
        fig = common.vis_faces(im_data)
        step = 0 if log_latest else self.global_step
        if subscript:
            path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
        else:
            path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path)
        plt.close(fig)

    def __get_save_dict(self):
        """Return the checkpoint payload: net state dict, options and latent_avg."""
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts),
        }
        save_dict['latent_avg'] = self.net.latent_avg
        return save_dict