# pytorchAnimeGAN / train.py
# Author: ptran1203 (commit f2fa83b, "first")
import torch
import argparse
import os
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from models.anime_gan import Discriminator
from datasets import AnimeDataSet
from utils.common import load_checkpoint
from trainer import Trainer
from utils.logger import get_logger
def parse_args():
    """Parse the command-line options for AnimeGAN training.

    Returns:
        argparse.Namespace: the parsed options. The checkpoint options
        ``resume_G_init``, ``resume_G`` and ``resume_D`` are plain strings
        whose default is the literal ``'False'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    # `choices` rejects unknown versions up front instead of failing later
    # when the generator is built.
    parser.add_argument('--model', type=str, default='v1', choices=['v1', 'v2', 'v3'],
                        help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', choices=['lsgan', 'hinge', 'bce'],
                        help='lsgan / hinge / bce')
    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
    parser.add_argument('--resume_G_init', type=str, default='False')
    parser.add_argument('--resume_G', type=str, default='False')
    parser.add_argument('--resume_D', type=str, default='False')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true')
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
                        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
    parser.add_argument('--resize_method', type=str, default="crop",
                        help="Resize image method if origin photo larger than imgsz")
    # Loss stuff
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")
    # Loss weight VGG19
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')
    # DDP
    parser.add_argument('--ddp', action='store_true')
    # Accept both spellings: newer PyTorch launchers pass `--local-rank`,
    # older `torch.distributed.launch` passes `--local_rank`.
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank", default=0, type=int)
    parser.add_argument("--world-size", default=2, type=int)
    return parser.parse_args()
def check_params(args):
    """Derive ``args.dataset`` from the input directories and validate options.

    Args:
        args: Namespace with ``real_image_dir``, ``anime_image_dir`` and
            ``gan_loss``. ``args.dataset`` is set in place.

    Raises:
        ValueError: if ``args.gan_loss`` is not a supported loss.
    """
    # normpath drops a trailing separator; otherwise basename("dataset/Hayao/")
    # would be "" and the dataset name would be wrong.
    real_name = os.path.basename(os.path.normpath(args.real_image_dir))
    anime_name = os.path.basename(os.path.normpath(args.anime_image_dir))
    # dataset/train_photo + dataset/Hayao -> train_photo_Hayao
    args.dataset = f"{real_name}_{anime_name}"
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if args.gan_loss not in {'lsgan', 'hinge', 'bce'}:
        raise ValueError(f'{args.gan_loss} is not supported')
def main(args, logger):
    """Build the AnimeGAN models, optionally resume from checkpoints, and train.

    Args:
        args: Namespace produced by ``parse_args``. It is mutated in place:
            ``check_params`` sets ``args.dataset``; on hosts without CUDA,
            ``device``, ``debug_samples`` and ``batch_size`` are overridden
            with a small debug configuration; ``init_epochs`` is set to 0 when
            both G and D weights are resumed.
        logger: Logger used for progress messages (only ``logger.info`` is
            called here).

    Raises:
        ValueError: if ``args.model`` is not ``'v1'``, ``'v2'`` or ``'v3'``.
    """
    check_params(args)
    # Only local rank 0 emits progress messages, so multiple processes do not
    # log duplicated lines.
    is_main_rank = args.local_rank == 0

    if not torch.cuda.is_available():
        if is_main_rank:
            logger.info("CUDA not found, use CPU")
        # Just for debugging purpose, set to minimum config
        # to avoid 🔥 the computer...
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    elif is_main_rank:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")

    # The discriminator uses layer norm for the v2 generator, instance norm otherwise.
    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)
    else:
        # Without this branch `G` would be unbound and fail obscurely below.
        raise ValueError(f"Unsupported model {args.model!r}, expected one of 'v1', 'v2', 'v3'")

    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )

    start_e = 0
    start_e_init = 0
    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )

    if args.resume_G_init.lower() != 'false':
        # Resume the G initialization phase from the given checkpoint.
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if is_main_rank:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        # Both G and D checkpoints must be provided; if either is missing this
        # branch is skipped and the `--resume` branch below is checked instead.
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if is_main_rank:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if is_main_rank:
                logger.info(f"D weight loaded from {args.resume_D}")
            # Both weights loaded: skip the G initialization phase.
            args.init_epochs = 0
        except Exception as e:
            # Best-effort resume: keep training rather than abort, but record the
            # failure through the training logger instead of a bare print().
            # NOTE(review): if G loaded but D failed, G keeps its loaded weights
            # and `start_e` keeps the loaded epoch — confirm that is intended.
            logger.info(f"Could not load checkpoint, train from scratch: {e}")
    elif args.resume:
        # Try to load both models from the trainer's own checkpoint paths.
        if is_main_rank:
            logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        if is_main_rank:
            logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        args.init_epochs = 0

    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if is_main_rank:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)
if __name__ == '__main__':
    args = parse_args()
    # normpath drops a trailing separator so "dataset/Hayao/" still yields "Hayao"
    # (plain basename would return "" and produce e.g. "runs_train_photo_").
    real_name = os.path.basename(os.path.normpath(args.real_image_dir))
    anime_name = os.path.basename(os.path.normpath(args.anime_image_dir))
    # e.g. runs + train_photo + Hayao -> runs_train_photo_Hayao
    args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"
    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))
    if args.local_rank == 0:
        # Dump the full resolved configuration once, from local rank 0 only.
        logger.info("# ==== Train Config ==== #")
        for arg in vars(args):
            logger.info(f"{arg} {getattr(args, arg)}")
        logger.info("==========================")
    main(args, logger)