import os
import math
import random

import numpy as np
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from utils import common
from criteria.lpips.lpips import LPIPS
from models.StyleGANControler import StyleGANControler
from training.ranger import Ranger
from expansion.submission import Expansion
from expansion.utils.flowlib import point_vec


class Coach:
    def __init__(self, opts):
        self.opts = opts
        if self.opts.checkpoint_path is None:
            self.global_step = 0
        else:
            # Resume the step counter from the checkpoint file name, e.g. 'iteration_50000.pt'
            self.global_step = int(os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0].split('_')[-1])

        self.device = 'cuda:0'  # TODO: Allow multiple GPUs? Currently using CUDA_VISIBLE_DEVICES
        self.opts.device = self.device

        # Initialize network
        self.net = StyleGANControler(self.opts).to(self.device)

        # Initialize losses
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        self.mse_loss = nn.MSELoss().to(self.device).eval()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        # Initialize optical flow estimator
        self.ex = Expansion()

        # Set flow normalization values per StyleGAN domain
        if 'ffhq' in self.opts.stylegan_weights:
            self.sigma_f = 4
            self.sigma_e = 0.02
        elif 'car' in self.opts.stylegan_weights:
            self.sigma_f = 5
            self.sigma_e = 0.03
        elif 'cat' in self.opts.stylegan_weights:
            self.sigma_f = 12
            self.sigma_e = 0.04
        elif 'church' in self.opts.stylegan_weights:
            self.sigma_f = 8
            self.sigma_e = 0.02
        elif 'anime' in self.opts.stylegan_weights:
            self.sigma_f = 7
            self.sigma_e = 0.025
        else:
            # Fail fast instead of hitting an AttributeError later in train()
            raise ValueError('No flow normalization values defined for {}'.format(self.opts.stylegan_weights))

    def train(self, truncation=0.3, sigma=0.1, target_layers=[0, 1, 2, 3, 4, 5]):
        # Build a sparse 16x16 sampling grid over the 256x256 image, normalized to [-1, 1) for grid_sample
        x = np.array(range(0, 256, 16)).astype(np.float32) / 127.5 - 1.
        y = np.array(range(0, 256, 16)).astype(np.float32) / 127.5 - 1.
        xx, yy = np.meshgrid(x, y)
        grid = np.concatenate([xx[:, :, None], yy[:, :, None]], axis=2)
        grid = torch.from_numpy(grid[None, :]).cuda()
        grid = grid.repeat(self.opts.batch_size, 1, 1, 1)

        while self.global_step < self.opts.max_steps:
            with torch.no_grad():
                # Sample a source latent z1 and per-layer perturbation latents z2
                z1 = torch.randn(self.opts.batch_size, 512).to("cuda")
                z2 = torch.randn(self.opts.batch_size, self.net.style_num, 512).to("cuda")
                x1, w1, f1 = self.net.decoder([z1], input_is_latent=False, randomize_noise=False,
                                              return_feature_map=True, return_latents=True,
                                              truncation=truncation, truncation_latent=self.net.latent_avg[0])
                x1 = self.net.face_pool(x1)
                x2, w2 = self.net.decoder([z2], input_is_latent=False, randomize_noise=False,
                                          return_latents=True, truncation_latent=self.net.latent_avg[0])
                x2 = self.net.face_pool(x2)

                # Move the target layers of w1 slightly toward w2 to synthesize a nearby edit
                w_mid = w1.clone()
                w_mid[:, target_layers] = w_mid[:, target_layers] + sigma * (w2[:, target_layers] - w_mid[:, target_layers])
                x_mid, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False, return_latents=False)
                x_mid = self.net.face_pool(x_mid)

                # Estimate optical flow and log expansion between the original and edited images
                flow, logexp = self.ex.run(x1.detach(), x_mid.detach())
                flow_feature = torch.cat([flow / self.sigma_f, logexp / self.sigma_e], dim=1)

                # Sample the StyleGAN feature map and the flow feature at the sparse grid points,
                # then flatten to (batch, points, channels) for the encoder
                f1 = F.interpolate(f1, size=flow_feature.shape[2:])
                f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True)
                flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True)
                flow_feature = flow_feature.view(flow_feature.shape[0], flow_feature.shape[1], -1).permute(0, 2, 1)
                f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0, 2, 1)

            # Predict the edited latents from the flow features and optimize the encoder
            self.net.train()
            self.optimizer.zero_grad()
            w_hat = self.net.encoder(w1[:, target_layers].detach(), flow_feature.detach(), f1.detach())
            loss, loss_dict, id_logs = self.calc_loss(w_hat, w_mid[:, target_layers].detach())
            loss.backward()
            self.optimizer.step()

            # Render the prediction for logging
            w_mid[:, target_layers] = w_hat.detach()
            x_hat, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False)
            x_hat = self.net.face_pool(x_hat)

            if self.global_step % self.opts.image_interval == 0 or (
                    self.global_step < 1000 and self.global_step % 100 == 0):
                # Overlay the estimated flow vectors on the source image
                imgL_o = ((x1.detach() + 1.) * 127.5)[0].permute(1, 2, 0).cpu().numpy()
                flow = torch.cat((flow, torch.ones_like(flow)[:, :1]), dim=1)[0].permute(1, 2, 0).cpu().numpy()
                flowvis = point_vec(imgL_o, flow)
                flowvis = torch.from_numpy(flowvis[:, :, ::-1].copy()).permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.
                self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images')
                print(loss_dict)

            if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                self.checkpoint_me(loss_dict, is_best=False)

            if self.global_step == self.opts.max_steps:
                print('OMG, finished training!')
                break

            self.global_step += 1

    def checkpoint_me(self, loss_dict, is_best):
        save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
            if is_best:
                f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
            else:
                f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))

    def configure_optimizers(self):
        params = list(self.net.encoder.parameters())
        if self.opts.train_decoder:
            params += list(self.net.decoder.parameters())
        if self.opts.optim_name == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
        else:
            optimizer = Ranger(params, lr=self.opts.learning_rate)
        return optimizer

    def calc_loss(self, latent, w, y_hat=None, y=None):
        loss_dict = {}
        loss = 0.0
        id_logs = None
        if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None):
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda
        if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None):
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda
        if self.opts.l2latent_lambda > 0:
            loss_l2 = F.mse_loss(latent, w)
            loss_dict['loss_l2latent'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2latent_lambda
        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1):
        im_data = []
        for i in range(display_count):
            cur_im_data = {
                'input_face': common.tensor2im(x[i]),
                'target_face': common.tensor2im(y[i]),
                'output_face': common.tensor2im(y_hat[i]),
            }
            if id_logs is not None:
                for key in id_logs[i]:
                    cur_im_data[key] = id_logs[i][key]
            im_data.append(cur_im_data)
        self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        fig = common.vis_faces(im_data)
        step = self.global_step
        if log_latest:
            step = 0
        if subscript:
            path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
        else:
            path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path)
        plt.close(fig)

    def __get_save_dict(self):
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts)
        }
        save_dict['latent_avg'] = self.net.latent_avg
        return save_dict
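

# Minimal usage sketch. This entry point is an assumption: pSp-style repositories
# typically expose an argparse wrapper providing the fields Coach reads (exp_dir,
# checkpoint_path, batch_size, max_steps, save_interval, image_interval,
# stylegan_weights, learning_rate, optim_name, train_decoder, lpips_lambda,
# l2_lambda, l2latent_lambda). The `TrainOptions` name and its module path below
# are hypothetical; adapt them to this repository's actual options module.
if __name__ == '__main__':
    from options.train_options import TrainOptions  # hypothetical import path

    opts = TrainOptions().parse()
    os.makedirs(opts.exp_dir, exist_ok=True)
    coach = Coach(opts)
    coach.train()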