# ast/main.py (uploaded by jgwill, commit 1b677c1 "add:ast-app", 7.11 kB)
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
#
# This file is part of Adaptive Style Transfer
#
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Adaptive Style Transfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import argparse
import tensorflow as tf
# Fix TensorFlow's graph-level random seed so random ops are repeatable between
# runs (note: some GPU kernels may still be nondeterministic).
tf.set_random_seed(228)
from model import Artgan
def parse_list(str_value):
    """Split a comma-separated command-line value into a list of strings.

    Intended as an argparse ``type`` so that ``--ii_dir a,b`` yields
    ``['a', 'b']``. A value without commas becomes a one-element list.
    Items are not stripped, so ``'a, b'`` gives ``['a', ' b']``.

    Args:
        str_value: raw string from the command line.

    Returns:
        list of str: the comma-separated pieces, in order.
    """
    # str.split returns [str_value] when no comma is present, so no special
    # case is needed for the single-item form.
    return str_value.split(',')
parser = argparse.ArgumentParser(description='')
# ========================== GENERAL PARAMETERS ========================= #
parser.add_argument('--model_name',
                    dest='model_name',
                    default='model1',
                    help='Name of the model')
# Restrict --phase to the values main() dispatches on, so a typo fails at
# argument parsing instead of silently doing nothing.
parser.add_argument('--phase',
                    dest='phase',
                    default='train',
                    choices=['train', 'inference', 'test',
                             'inference_on_frames', 'test_on_frames'],
                    help='Specify current phase: train, inference (alias: test) '
                         'or inference_on_frames (alias: test_on_frames).')
# The help pieces below are joined by implicit string concatenation, so each
# piece must end with a space.
parser.add_argument('--image_size',
                    dest='image_size',
                    type=int,
                    default=256*3,
                    help='For training phase: will crop out images of this particular size. '
                         'For inference phase: each input image will have the smallest side of this size. '
                         'For inference recommended size is 1280.')
# ========================= TRAINING PARAMETERS ========================= #
parser.add_argument('--ptad',
                    dest='path_to_art_dataset',
                    type=str,
                    default='./data/vincent-van-gogh_road-with-cypresses-1890',
                    help='Directory with paintings representing style we want to learn.')
parser.add_argument('--ptcd',
                    dest='path_to_content_dataset',
                    type=str,
                    default=None,
                    help='Path to Places365 training dataset.')
parser.add_argument('--total_steps',
                    dest='total_steps',
                    type=int,
                    default=int(3e5),
                    help='Total number of steps')
parser.add_argument('--batch_size',
                    dest='batch_size',
                    type=int,
                    default=1,
                    help='# images in batch')
parser.add_argument('--lr',
                    dest='lr',
                    type=float,
                    default=0.0002,
                    help='initial learning rate for adam')
parser.add_argument('--save_freq',
                    dest='save_freq',
                    type=int,
                    default=1000,
                    help='Save model every save_freq steps')
parser.add_argument('--ngf',
                    dest='ngf',
                    type=int,
                    default=32,
                    help='Number of filters in first conv layer of generator(encoder-decoder).')
parser.add_argument('--ndf',
                    dest='ndf',
                    type=int,
                    default=64,
                    help='Number of filters in first conv layer of discriminator.')
# Weights of different losses.
parser.add_argument('--dlw',
                    dest='discr_loss_weight',
                    type=float,
                    default=1.,
                    help='Weight of discriminator loss.')
parser.add_argument('--tlw',
                    dest='transformer_loss_weight',
                    type=float,
                    default=100.,
                    help='Weight of transformer loss.')
parser.add_argument('--flw',
                    dest='feature_loss_weight',
                    type=float,
                    default=100.,
                    help='Weight of feature loss.')
parser.add_argument('--dsr',
                    dest='discr_success_rate',
                    type=float,
                    default=0.8,
                    help='Rate of trials that discriminator will win on average.')
# ========================= INFERENCE PARAMETERS ========================= #
# parse_list turns "dir1,dir2" into ['dir1', 'dir2']. The default is already a
# list, and argparse only applies `type` to string defaults, so it is used as-is.
parser.add_argument('--ii_dir',
                    dest='inference_images_dir',
                    type=parse_list,
                    default=['./data/sample_photographs/'],
                    help='Directory with images we want to process. '
                         'Several directories may be given as a comma-separated list.')
parser.add_argument('--save_dir',
                    type=str,
                    default=None,
                    help='Directory to save inference output images. '
                         'If not specified will save in the model directory.')
parser.add_argument('--file_suffix',
                    type=str,
                    default='_stylized',
                    help='Suffix to insert between the output file name and its extension.')
# NOTE: main() also forwards this value to model.train(), not only to inference.
parser.add_argument('--ckpt_nmbr',
                    dest='ckpt_nmbr',
                    type=int,
                    default=None,
                    help='Checkpoint number we want to use for inference. '
                         'Might be None(unspecified), then the latest available will be used.')
# Parsed at import time; main() reads this module-level `args`.
args = parser.parse_args()
def main(_):
    """Build an Artgan model in a TF session and run the phase selected by --phase.

    Supported phases:
      * ``train``: ``model.train``.
      * ``inference`` / ``test``: ``model.inference`` over every directory in
        ``args.inference_images_dir``.
      * ``inference_on_frames`` / ``test_on_frames``: ``model.inference_video``
        on the FIRST directory of ``args.inference_images_dir`` only.

    Args:
        _: argv forwarded by ``tf.app.run``; unused because options are read
            from the module-level ``args``.

    Raises:
        ValueError: if ``args.phase`` is not one of the supported phases.
    """
    phase = args.phase
    if phase not in ('train', 'inference', 'test',
                     'inference_on_frames', 'test_on_frames'):
        # Fail fast instead of building the graph and then doing nothing.
        raise ValueError('Unknown phase: %r' % (phase,))

    tfconfig = tf.ConfigProto(allow_soft_placement=False)
    # Grow GPU memory usage on demand instead of reserving it all up front.
    tfconfig.gpu_options.allow_growth = True
    # The `with` block closes the session on exit, so no explicit close() is needed.
    with tf.Session(config=tfconfig) as sess:
        model = Artgan(sess, args)
        if phase == 'train':
            model.train(args, ckpt_nmbr=args.ckpt_nmbr)
        elif phase in ('inference', 'test'):
            print("Inference.")
            model.inference(args, args.inference_images_dir, resize_to_original=False,
                            to_save_dir=args.save_dir,
                            ckpt_nmbr=args.ckpt_nmbr,
                            file_suffix=args.file_suffix)
        else:  # 'inference_on_frames' or 'test_on_frames'
            print("Inference on frames sequence.")
            model.inference_video(args,
                                  path_to_folder=args.inference_images_dir[0],
                                  resize_to_original=False,
                                  to_save_dir=args.save_dir,
                                  ckpt_nmbr=args.ckpt_nmbr,
                                  file_suffix=args.file_suffix)
if __name__ == '__main__':
    # Hand main() to TensorFlow's app runner explicitly rather than relying on
    # it looking up `main` in the __main__ module.
    tf.app.run(main=main)