# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
#
# This file is part of Adaptive Style Transfer
#
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Adaptive Style Transfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from __future__ import print_function

import multiprocessing
import os
import sys
import time
from collections import namedtuple
from glob import glob

import numpy as np
import scipy.misc
import tensorflow as tf
from tqdm import tqdm

from module import *
from utils import *
import prepare_dataset
import img_augm


class Artgan(object):
    def __init__(self, sess, args):
        self.model_name = args.model_name
        self.root_dir = './models'
        self.checkpoint_dir = os.path.join(self.root_dir, self.model_name, 'checkpoint')
        self.checkpoint_long_dir = os.path.join(self.root_dir, self.model_name, 'checkpoint_long')
        self.sample_dir = os.path.join(self.root_dir, self.model_name, 'sample')
        self.inference_dir = os.path.join(self.root_dir, self.model_name, 'inference')
        self.logs_dir = os.path.join(self.root_dir, self.model_name, 'logs')

        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.image_size
        self.loss = sce_criterion

        self.initial_step = 0

        OPTIONS = namedtuple('OPTIONS',
                             'batch_size image_size '
                             'total_steps save_freq lr '
                             'gf_dim df_dim '
                             'is_training '
                             'path_to_content_dataset '
                             'path_to_art_dataset '
                             'discr_loss_weight transformer_loss_weight feature_loss_weight')
        self.options = OPTIONS._make((args.batch_size, args.image_size,
                                      args.total_steps, args.save_freq, args.lr,
                                      args.ngf, args.ndf,
                                      args.phase == 'train',
                                      args.path_to_content_dataset,
                                      args.path_to_art_dataset,
                                      args.discr_loss_weight,
                                      args.transformer_loss_weight,
                                      args.feature_loss_weight))

        # Create all the folders for saving the model.
        for folder in (self.root_dir,
                       os.path.join(self.root_dir, self.model_name),
                       self.checkpoint_dir,
                       self.checkpoint_long_dir,
                       self.sample_dir,
                       self.inference_dir):
            if not os.path.exists(folder):
                os.makedirs(folder)

        self._build_model()
        #@STCGoal Keep the entire sequence of checkpoints saved every 1000 steps.
        #@q Would setting max_to_keep below to 405 keep the whole sequence?
        self.saver = tf.train.Saver(max_to_keep=2)
        self.saver_long = tf.train.Saver(max_to_keep=None)

    def _build_model(self):
        if self.options.is_training:
            # ==================== Define placeholders. ===================== #
            with tf.name_scope('placeholder'):
                self.input_painting = tf.placeholder(dtype=tf.float32,
                                                     shape=[self.batch_size, None, None, 3],
                                                     name='painting')
                self.input_photo = tf.placeholder(dtype=tf.float32,
                                                  shape=[self.batch_size, None, None, 3],
                                                  name='photo')
                self.lr = tf.placeholder(dtype=tf.float32, shape=(), name='learning_rate')

            # ===================== Wire the graph. ========================= #
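            # The training graph, schematically:
            #   photo --encoder--> features --decoder--> stylized output
            # The output is encoded a second time for the feature loss, and the
            # discriminator scores paintings, photos and outputs at several
            # scales (encoder/decoder/discriminator are defined in module.py).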
            # Encode input images.
            self.input_photo_features = encoder(image=self.input_photo,
                                                options=self.options,
                                                reuse=False)

            # Decode obtained features.
            self.output_photo = decoder(features=self.input_photo_features,
                                        options=self.options,
                                        reuse=False)

            # Get features of output images. Need them to compute feature loss.
            self.output_photo_features = encoder(image=self.output_photo,
                                                 options=self.options,
                                                 reuse=True)

            # Add discriminators.
            # Note that each prediction contains multiple predictions
            # at different scales.
            self.input_painting_discr_predictions = discriminator(image=self.input_painting,
                                                                  options=self.options,
                                                                  reuse=False)
            self.input_photo_discr_predictions = discriminator(image=self.input_photo,
                                                               options=self.options,
                                                               reuse=True)
            self.output_photo_discr_predictions = discriminator(image=self.output_photo,
                                                                options=self.options,
                                                                reuse=True)

            # ===================== Final losses that we optimize. ===================== #

            # Discriminator.
            # Has to predict ones only for original paintings, zeros otherwise.
            scale_weight = {"scale_0": 1.,
                            "scale_1": 1.,
                            "scale_3": 1.,
                            "scale_5": 1.,
                            "scale_6": 1.}
            self.input_painting_discr_loss = {key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
                                              for key, pred in self.input_painting_discr_predictions.items()}
            self.input_photo_discr_loss = {key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
                                           for key, pred in self.input_photo_discr_predictions.items()}
            self.output_photo_discr_loss = {key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
                                            for key, pred in self.output_photo_discr_predictions.items()}

            self.discr_loss = tf.add_n(list(self.input_painting_discr_loss.values())) + \
                              tf.add_n(list(self.input_photo_discr_loss.values())) + \
                              tf.add_n(list(self.output_photo_discr_loss.values()))

            # Compute discriminator accuracies. Predictions are raw logits,
            # so a positive logit means "classified as painting".
            self.input_painting_discr_acc = {key: tf.reduce_mean(tf.cast(pred > 0., tf.float32)) * scale_weight[key]
                                             for key, pred in self.input_painting_discr_predictions.items()}
            self.input_photo_discr_acc = {key: tf.reduce_mean(tf.cast(pred < 0., tf.float32)) * scale_weight[key]
                                          for key, pred in self.input_photo_discr_predictions.items()}
            self.output_photo_discr_acc = {key: tf.reduce_mean(tf.cast(pred < 0., tf.float32)) * scale_weight[key]
                                           for key, pred in self.output_photo_discr_predictions.items()}

            self.discr_acc = (tf.add_n(list(self.input_painting_discr_acc.values())) +
                              tf.add_n(list(self.input_photo_discr_acc.values())) +
                              tf.add_n(list(self.output_photo_discr_acc.values()))) / float(len(scale_weight) * 3)

            # Generator.
            # Predicts ones for both output images.
            self.output_photo_gener_loss = {key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
                                            for key, pred in self.output_photo_discr_predictions.items()}

            self.gener_loss = tf.add_n(list(self.output_photo_gener_loss.values()))
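
            # A stylized photo "fools" the discriminator on a given scale when
            # it receives a positive logit there. gener_acc (below) averages
            # this over scales; train() feeds it into the moving-average
            # schedule that decides whether the generator or the discriminator
            # gets updated at each step.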
            # Compute generator accuracies.
            self.output_photo_gener_acc = {key: tf.reduce_mean(tf.cast(pred > 0., tf.float32)) * scale_weight[key]
                                           for key, pred in self.output_photo_discr_predictions.items()}

            self.gener_acc = tf.add_n(list(self.output_photo_gener_acc.values())) / float(len(scale_weight))

            # Image loss.
            self.img_loss_photo = mse_criterion(transformer_block(self.output_photo),
                                                transformer_block(self.input_photo))
            self.img_loss = self.img_loss_photo

            # Features loss.
            self.feature_loss_photo = abs_criterion(self.output_photo_features, self.input_photo_features)
            self.feature_loss = self.feature_loss_photo

            # ================== Define optimization steps. =============== #
            t_vars = tf.trainable_variables()
            self.discr_vars = [var for var in t_vars if 'discriminator' in var.name]
            self.encoder_vars = [var for var in t_vars if 'encoder' in var.name]
            self.decoder_vars = [var for var in t_vars if 'decoder' in var.name]

            # Discriminator and generator steps.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.d_optim_step = tf.train.AdamOptimizer(self.lr).minimize(
                    loss=self.options.discr_loss_weight * self.discr_loss,
                    var_list=self.discr_vars)

                self.g_optim_step = tf.train.AdamOptimizer(self.lr).minimize(
                    loss=self.options.discr_loss_weight * self.gener_loss +
                         self.options.transformer_loss_weight * self.img_loss +
                         self.options.feature_loss_weight * self.feature_loss,
                    var_list=self.encoder_vars + self.decoder_vars)
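
            # Schematically, g_optim_step minimizes
            #   L_G = w_d * gener_loss + w_t * img_loss + w_f * feature_loss
            # with w_d = discr_loss_weight, w_t = transformer_loss_weight and
            # w_f = feature_loss_weight taken from self.options, while
            # d_optim_step minimizes w_d * discr_loss over the discriminator
            # variables only.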
            # ============= Write statistics to tensorboard. ================ #

            # Discriminator loss summary.
            s_d1 = [tf.summary.scalar("discriminator/input_painting_discr_loss/" + key, val)
                    for key, val in self.input_painting_discr_loss.items()]
            s_d2 = [tf.summary.scalar("discriminator/input_photo_discr_loss/" + key, val)
                    for key, val in self.input_photo_discr_loss.items()]
            s_d3 = [tf.summary.scalar("discriminator/output_photo_discr_loss/" + key, val)
                    for key, val in self.output_photo_discr_loss.items()]
            s_d = tf.summary.scalar("discriminator/discr_loss", self.discr_loss)
            self.summary_discriminator_loss = tf.summary.merge(s_d1 + s_d2 + s_d3 + [s_d])

            # Discriminator acc summary.
            s_d1_acc = [tf.summary.scalar("discriminator/input_painting_discr_acc/" + key, val)
                        for key, val in self.input_painting_discr_acc.items()]
            s_d2_acc = [tf.summary.scalar("discriminator/input_photo_discr_acc/" + key, val)
                        for key, val in self.input_photo_discr_acc.items()]
            s_d3_acc = [tf.summary.scalar("discriminator/output_photo_discr_acc/" + key, val)
                        for key, val in self.output_photo_discr_acc.items()]
            s_d_acc = tf.summary.scalar("discriminator/discr_acc", self.discr_acc)
            # Use a distinct tag for the generator accuracy so it does not
            # collide with the discriminator accuracy tag above.
            s_d_acc_g = tf.summary.scalar("discriminator/gener_acc", self.gener_acc)
            self.summary_discriminator_acc = tf.summary.merge(s_d1_acc + s_d2_acc + s_d3_acc + [s_d_acc])

            # Image loss summary.
            s_i1 = tf.summary.scalar("image_loss/photo", self.img_loss_photo)
            s_i = tf.summary.scalar("image_loss/loss", self.img_loss)
            self.summary_image_loss = tf.summary.merge([s_i1, s_i])

            # Feature loss summary.
            s_f1 = tf.summary.scalar("feature_loss/photo", self.feature_loss_photo)
            s_f = tf.summary.scalar("feature_loss/loss", self.feature_loss)
            self.summary_feature_loss = tf.summary.merge([s_f1, s_f])

            self.summary_merged_all = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.logs_dir, self.sess.graph)
        else:
            # ==================== Define placeholders. ===================== #
            with tf.name_scope('placeholder'):
                self.input_photo = tf.placeholder(dtype=tf.float32,
                                                  shape=[self.batch_size, None, None, 3],
                                                  name='photo')

            # ===================== Wire the graph. ========================= #
            # Encode input images.
            self.input_photo_features = encoder(image=self.input_photo,
                                                options=self.options,
                                                reuse=False)
            # Decode obtained features.
            self.output_photo = decoder(features=self.input_photo_features,
                                        options=self.options,
                                        reuse=False)

    def train(self, args, ckpt_nmbr=None):
        # Initialize augmentor.
        augmentor = img_augm.Augmentor(crop_size=[self.options.image_size, self.options.image_size],
                                       vertical_flip_prb=0.,
                                       hsv_augm_prb=1.0,
                                       hue_augm_shift=0.05,
                                       saturation_augm_shift=0.05, saturation_augm_scale=0.05,
                                       value_augm_shift=0.05, value_augm_scale=0.05)
        content_dataset_places = prepare_dataset.PlacesDataset(path_to_dataset=self.options.path_to_content_dataset)
        art_dataset = prepare_dataset.ArtDataset(path_to_art_dataset=self.options.path_to_art_dataset)

        # Initialize queue workers for both datasets.
        q_art = multiprocessing.Queue(maxsize=10)
        q_content = multiprocessing.Queue(maxsize=10)
        jobs = []
        for i in range(5):
            p = multiprocessing.Process(target=content_dataset_places.initialize_batch_worker,
                                        args=(q_content, augmentor, self.batch_size, i))
            p.start()
            jobs.append(p)

            p = multiprocessing.Process(target=art_dataset.initialize_batch_worker,
                                        args=(q_art, augmentor, self.batch_size, i))
            p.start()
            jobs.append(p)
        print("Processes are started.")
        time.sleep(3)

        # Now initialize the graph.
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        print("Start training.")
        if self.load(self.checkpoint_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        elif self.load(self.checkpoint_long_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Initial discriminator success rate.
        win_rate = args.discr_success_rate
        discr_success = args.discr_success_rate
        alpha = 0.05

        for step in tqdm(range(self.initial_step, self.options.total_steps + 1),
                         initial=self.initial_step,
                         total=self.options.total_steps):
            # Get the next batches; Queue.get() blocks until one is available,
            # so no busy-waiting on q_art/q_content is needed.
            batch_art = q_art.get()
            batch_content = q_content.get()
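            # Adversarial scheduling: discr_success is an exponential moving
            # average (rate alpha) of how often the discriminator wins. While
            # it is at or above win_rate, the generator trains; once it drops
            # below, the discriminator trains until it recovers.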
            if discr_success >= win_rate:
                # Train generator.
                _, summary_all, gener_acc_ = self.sess.run(
                    [self.g_optim_step, self.summary_merged_all, self.gener_acc],
                    feed_dict={
                        self.input_painting: normalize_arr_of_imgs(batch_art['image']),
                        self.input_photo: normalize_arr_of_imgs(batch_content['image']),
                        self.lr: self.options.lr
                    })
                discr_success = discr_success * (1. - alpha) + alpha * (1. - gener_acc_)
            else:
                # Train discriminator.
                _, summary_all, discr_acc_ = self.sess.run(
                    [self.d_optim_step, self.summary_merged_all, self.discr_acc],
                    feed_dict={
                        self.input_painting: normalize_arr_of_imgs(batch_art['image']),
                        self.input_photo: normalize_arr_of_imgs(batch_content['image']),
                        self.lr: self.options.lr
                    })
                discr_success = discr_success * (1. - alpha) + alpha * discr_acc_

            self.writer.add_summary(summary_all, step * self.batch_size)

            if step % self.options.save_freq == 0 and step > self.initial_step:
                self.save(step)

            # Additionally keep a long-term checkpoint every 15000 steps.
            if step % 15000 == 0 and step > self.initial_step:
                self.save(step, is_long=True)

            if step % 500 == 0:
                output_paintings_, output_photos_ = self.sess.run(
                    [self.input_painting, self.output_photo],
                    feed_dict={
                        self.input_painting: normalize_arr_of_imgs(batch_art['image']),
                        self.input_photo: normalize_arr_of_imgs(batch_content['image']),
                        self.lr: self.options.lr
                    })

                save_batch(input_painting_batch=batch_art['image'],
                           input_photo_batch=batch_content['image'],
                           output_painting_batch=denormalize_arr_of_imgs(output_paintings_),
                           output_photo_batch=denormalize_arr_of_imgs(output_photos_),
                           filepath='%s/step_%d.jpg' % (self.sample_dir, step))

        print("Training is finished. Terminate jobs.")
        for p in jobs:
            # The workers loop forever, so terminate them before joining;
            # joining first would block indefinitely.
            p.terminate()
            p.join()
        print("Done.")
        # Exit explicitly so nothing keeps the parent process alive.
        sys.exit()

    # Don't use this function yet.
    def inference_video(self, args, path_to_folder, to_save_dir=None, resize_to_original=True,
                        use_time_smooth_randomness=True, ckpt_nmbr=None, file_suffix="_stylized"):
        """Run inference on video frames. The original aspect ratio is preserved.

        Args:
            args: command-line arguments (kept for interface parity, unused here).
            path_to_folder: path to the folder with frames from the video.
            to_save_dir: where to store the stylized frames; defaults to a folder
                named after the checkpoint inside the model directory.
            resize_to_original: resize stylized frames back to the input size.
            use_time_smooth_randomness: change the random vector which is added
                to the bottleneck features linearly over time.
        Returns:
            None.
        """
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        print("Start inference.")
        if self.load(self.checkpoint_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        elif self.load(self.checkpoint_long_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Create folder to store results.
        if to_save_dir is None:
            to_save_dir = os.path.join(self.root_dir, self.model_name,
                                       'inference_ckpt%d_sz%d' % (self.initial_step, self.image_size))
        if not os.path.exists(to_save_dir):
            os.makedirs(to_save_dir)

        image_paths = sorted(os.listdir(path_to_folder))
        num_images = len(image_paths)
        for img_idx, img_name in enumerate(tqdm(image_paths)):
            img_path = os.path.join(path_to_folder, img_name)
            img = scipy.misc.imread(img_path, mode='RGB')
            img_shape = img.shape[:2]

            # Prepare image for feeding into network.
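            # The smallest side is scaled to self.image_size; e.g. a 1080x1920
            # frame with image_size=1280 is scaled by 1280/1080 ~ 1.185 into
            # roughly 1280x2275 before being fed to the network.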
            scale_mult = self.image_size / np.min(img_shape)
            new_shape = (np.array(img_shape, dtype=float) * scale_mult).astype(int)
            img = scipy.misc.imresize(img, size=new_shape)
            img = np.expand_dims(img, axis=0)

            if use_time_smooth_randomness and img_idx == 0:
                # NOTE: self.labels_to_concatenate_to_features is not defined in
                # this class; this experimental branch is one reason the method
                # is marked as not ready for use.
                features_delta = self.sess.run(self.labels_to_concatenate_to_features,
                                               feed_dict={
                                                   self.input_photo: normalize_arr_of_imgs(img),
                                               })
                features_delta_start = features_delta + np.random.random(size=features_delta.shape) * 0.5 - 0.25
                features_delta_start = features_delta_start.clip(0, 1000)
                print('features_delta_start.shape=', features_delta_start.shape)
                features_delta_end = features_delta + np.random.random(size=features_delta.shape) * 0.5 - 0.25
                features_delta_end = features_delta_end.clip(0, 1000)
                step = (features_delta_end - features_delta_start) / (num_images - 1)

            feed_dict = {
                self.input_painting: normalize_arr_of_imgs(img),
                self.input_photo: normalize_arr_of_imgs(img),
                self.lr: self.options.lr
            }
            if use_time_smooth_randomness:
                # Interpolating the feature offset over time is not implemented yet.
                pass

            img = self.sess.run(self.output_photo, feed_dict=feed_dict)

            img = img[0]
            img = denormalize_arr_of_imgs(img)
            if resize_to_original:
                img = scipy.misc.imresize(img, size=img_shape)

            scipy.misc.imsave(os.path.join(to_save_dir, img_name[:-4] + file_suffix + ".jpg"), img)

        print("Inference is finished.")

    def inference(self, args, path_to_folder, to_save_dir=None, resize_to_original=True,
                  ckpt_nmbr=None, file_suffix="_stylized"):
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        print("Start inference.")
        if self.load(self.checkpoint_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        elif self.load(self.checkpoint_long_dir, ckpt_nmbr):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            # Exit if we cannot load a checkpoint: inference with uninitialized
            # weights would only produce noisy images.
            sys.exit()

        # Create folder to store results.
        if to_save_dir is None:
            to_save_dir = os.path.join(self.root_dir, self.model_name,
                                       'inference_ckpt%d_sz%d' % (self.initial_step, self.image_size))
        if not os.path.exists(to_save_dir):
            os.makedirs(to_save_dir)

        names = []
        for d in path_to_folder:
            names += glob(os.path.join(d, '*'))
        names = [x for x in names if os.path.basename(x)[0] != '.']
        names.sort()
        for img_idx, img_path in enumerate(tqdm(names)):
            img = scipy.misc.imread(img_path, mode='RGB')
            img_shape = img.shape[:2]

            # Resize the smallest side of the image to self.image_size.
            alpha = float(self.image_size) / float(min(img_shape))
            img = scipy.misc.imresize(img, size=alpha)
            img = np.expand_dims(img, axis=0)

            img = self.sess.run(self.output_photo,
                                feed_dict={
                                    self.input_photo: normalize_arr_of_imgs(img),
                                })
            img = img[0]
            img = denormalize_arr_of_imgs(img)
            if resize_to_original:
                img = scipy.misc.imresize(img, size=img_shape)

            img_name = os.path.basename(img_path)
            #@STCGoal Append file_suffix to the output file name.
            scipy.misc.imsave(os.path.join(to_save_dir, img_name[:-4] + file_suffix + ".jpg"), img)

        print("Inference is finished.")

    def save(self, step, is_long=False):
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        if is_long:
            self.saver_long.save(self.sess,
                                 os.path.join(self.checkpoint_long_dir, self.model_name + '_%d.ckpt' % step),
                                 global_step=step)
        else:
            self.saver.save(self.sess,
                            os.path.join(self.checkpoint_dir, self.model_name + '_%d.ckpt' % step),
                            global_step=step)
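
    # Checkpoint files are named <model_name>_<step>.ckpt with the global step
    # appended by tf.train.Saver (e.g. model_van-gogh_10000.ckpt-10000.*);
    # load() recovers self.initial_step by parsing that step back out of the
    # file name.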
    def load(self, checkpoint_dir, ckpt_nmbr=None):
        if ckpt_nmbr:
            if len([x for x in os.listdir(checkpoint_dir) if ("ckpt-" + str(ckpt_nmbr)) in x]) > 0:
                print(" [*] Reading checkpoint %d from folder %s." % (ckpt_nmbr, checkpoint_dir))
                ckpt_name = [x for x in os.listdir(checkpoint_dir) if ("ckpt-" + str(ckpt_nmbr)) in x][0]
                ckpt_name = '.'.join(ckpt_name.split('.')[:-1])
                self.initial_step = ckpt_nmbr
                print("Load checkpoint %s. Initial step: %s." % (ckpt_name, self.initial_step))
                self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
                return True
            else:
                return False
        else:
            print(" [*] Reading latest checkpoint from folder %s." % checkpoint_dir)
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                self.initial_step = int(ckpt_name.split("_")[-1].split(".")[0])
                print("Load checkpoint %s. Initial step: %s." % (ckpt_name, self.initial_step))
                self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
                return True
            else:
                return False
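
# ---------------------------------------------------------------------------
# Usage sketch. In the repository the class is driven by a separate entry
# point built around argparse; the attribute names below mirror the args
# consumed in __init__ and train(), but the concrete values are illustrative
# assumptions, not verified project defaults.
#
#   from argparse import Namespace
#   args = Namespace(model_name='model_van-gogh', batch_size=1, image_size=768,
#                    total_steps=300000, save_freq=1000, lr=0.0002,
#                    ngf=32, ndf=64, phase='inference',
#                    path_to_content_dataset=None, path_to_art_dataset=None,
#                    discr_loss_weight=1., transformer_loss_weight=100.,
#                    feature_loss_weight=100., discr_success_rate=0.8)
#   with tf.Session() as sess:
#       model = Artgan(sess, args)
#       model.inference(args, path_to_folder=['./input'], to_save_dir='./output')
# ---------------------------------------------------------------------------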