"""Command-line entry point that builds the models and dataset and starts training."""
import matplotlib

# Backend selection is kept ahead of the remaining imports, as in the original ordering.
matplotlib.use("Agg")

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

from frames_dataset import FramesDataset

from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector, HEEstimator

import torch

from train import train


def _parse_device_ids(text):
    """Turn a comma separated string such as ``"0, 1"`` into ``[0, 1]``."""
    # int() tolerates surrounding whitespace, so "0, 1" and "0,1" both work.
    return [int(part) for part in text.split(",")]


def _parse_args():
    """Define the command-line interface and return the parsed options."""
    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train"])
    parser.add_argument("--gen", default="original", choices=["original", "spade"])
    parser.add_argument("--log_dir", default="log", help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    parser.add_argument(
        "--device_ids",
        default="0, 1, 2, 3, 4, 5, 6, 7",
        type=_parse_device_ids,
        help="Names of the devices comma separated.",
    )
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)
    return parser.parse_args()


def _resolve_log_dir(checkpoint, log_root, config_path):
    """Return the directory this run logs into.

    When ``checkpoint`` is given, the directory containing it is reused.
    Otherwise a new directory name is built under ``log_root`` from the config
    file's base name (text before the first ".") plus a GMT timestamp.
    """
    if checkpoint is not None:
        # A bare file name has an empty dirname; fall back to the current
        # directory so the later os.makedirs() call does not fail on "".
        return os.path.dirname(checkpoint) or "."
    run_name = os.path.basename(config_path).split(".")[0]
    return os.path.join(log_root, run_name) + " " + strftime("%d_%m_%y_%H.%M.%S", gmtime())


def _place_and_report(model, device_ids, use_cuda, verbose):
    """Move ``model`` to the first listed device when CUDA is usable; print it if verbose."""
    if use_cuda:
        model.to(device_ids[0])
    if verbose:
        print(model)
    return model


def main():
    """Parse options, build the networks and dataset, then launch training."""
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    opt = _parse_args()

    with open(opt.config) as f:
        # NOTE(review): FullLoader is not meant for untrusted input; only pass
        # config files you trust, or consider yaml.safe_load if the configs
        # contain plain YAML only.
        config = yaml.load(f, Loader=yaml.FullLoader)

    log_dir = _resolve_log_dir(opt.checkpoint, opt.log_dir, opt.config)

    model_params = config["model_params"]
    common_params = model_params["common_params"]
    use_cuda = torch.cuda.is_available()

    # argparse restricts --gen to exactly these two keys.
    generator_classes = {
        "original": OcclusionAwareGenerator,
        "spade": OcclusionAwareSPADEGenerator,
    }
    generator = generator_classes[opt.gen](**model_params["generator_params"], **common_params)
    if use_cuda:
        print("cuda is available")
    _place_and_report(generator, opt.device_ids, use_cuda, opt.verbose)

    discriminator = MultiScaleDiscriminator(**model_params["discriminator_params"], **common_params)
    _place_and_report(discriminator, opt.device_ids, use_cuda, opt.verbose)

    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    _place_and_report(kp_detector, opt.device_ids, use_cuda, opt.verbose)

    he_estimator = HEEstimator(**model_params["he_estimator_params"], **common_params)
    # Printed under --verbose as well, for consistency with the other three models.
    _place_and_report(he_estimator, opt.device_ids, use_cuda, opt.verbose)

    dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"])

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(log_dir, exist_ok=True)
    # Keep a copy of the config next to the logs unless one is already there.
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == "train":
        print("Training...")
        train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids)


if __name__ == "__main__":
    main()