GowthamYarlagadda's picture
Upload 304 files
b36e9ec verified
raw
history blame contribute delete
No virus
3.49 kB
import matplotlib
matplotlib.use("Agg")
import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
from frames_dataset import FramesDataset
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector, HEEstimator
import torch
from train import train
if __name__ == "__main__":
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
parser = ArgumentParser()
parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
parser.add_argument(
"--mode",
default="train",
choices=[
"train",
],
)
parser.add_argument("--gen", default="original", choices=["original", "spade"])
parser.add_argument("--log_dir", default="log", help="path to log into")
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
parser.add_argument(
"--device_ids",
default="0, 1, 2, 3, 4, 5, 6, 7",
type=lambda x: list(map(int, x.split(","))),
help="Names of the devices comma separated.",
)
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
parser.set_defaults(verbose=False)
opt = parser.parse_args()
with open(opt.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if opt.checkpoint is not None:
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
else:
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split(".")[0])
log_dir += " " + strftime("%d_%m_%y_%H.%M.%S", gmtime())
if opt.gen == "original":
generator = OcclusionAwareGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
elif opt.gen == "spade":
generator = OcclusionAwareSPADEGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
if torch.cuda.is_available():
print("cuda is available")
generator.to(opt.device_ids[0])
if opt.verbose:
print(generator)
discriminator = MultiScaleDiscriminator(**config["model_params"]["discriminator_params"], **config["model_params"]["common_params"])
if torch.cuda.is_available():
discriminator.to(opt.device_ids[0])
if opt.verbose:
print(discriminator)
kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"])
if torch.cuda.is_available():
kp_detector.to(opt.device_ids[0])
if opt.verbose:
print(kp_detector)
he_estimator = HEEstimator(**config["model_params"]["he_estimator_params"], **config["model_params"]["common_params"])
if torch.cuda.is_available():
he_estimator.to(opt.device_ids[0])
dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"])
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
copy(opt.config, log_dir)
if opt.mode == "train":
print("Training...")
train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids)