import importlib
from argparse import ArgumentParser
from omegaconf import OmegaConf
from os.path import join as pjoin
import os
import glob


def get_module_config(cfg, filepath="./configs"):
    """
    Load yaml config files from the subfolders of ``filepath`` into ``cfg``.

    Every file matching ``<filepath>/<subfolder>/<name>.yaml`` is loaded and
    written into ``cfg`` under the dotted key ``<subfolder>.<name>``.

    Args:
        cfg: OmegaConf config to update (updated in place and returned).
        filepath: Root config folder to scan.

    Returns:
        The updated ``cfg``.
    """
    # Sorted so the update order does not depend on filesystem listing order.
    yamls = sorted(glob.glob(pjoin(filepath, '*', '*.yaml')))
    for yaml_path in yamls:
        # Path relative to the config root, e.g. "model/vae.yaml" -> "model.vae".
        # relpath/os.sep (instead of str.replace on '/') keeps this correct for
        # any filepath and on Windows.
        rel = os.path.relpath(yaml_path, filepath)
        nodes = os.path.splitext(rel)[0].replace(os.sep, '.')
        # Load the file that was actually found, not a hard-coded "./configs"
        # root, so a custom ``filepath`` (e.g. CONFIG_FOLDER) works.
        OmegaConf.update(cfg, nodes, OmegaConf.load(yaml_path))
    return cfg


def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as ``"package.module.Name"`` to the object.

    Args:
        string: Fully qualified name; the part after the last dot is looked
            up as an attribute of the module named by the part before it.
        reload: If True, reload the module before the lookup.

    Returns:
        The attribute ``Name`` of the imported module.
    """
    module, cls = string.rsplit(".", 1)
    module_imp = importlib.import_module(module)
    if reload:
        module_imp = importlib.reload(module_imp)
    return getattr(module_imp, cls)


def instantiate_from_config(config):
    """
    Instantiate an object described by ``config``.

    ``config["target"]`` is the dotted path of a callable, and the optional
    ``config["params"]`` mapping is passed to it as keyword arguments.

    Raises:
        KeyError: If ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def resume_config(cfg: OmegaConf):
    """
    Point the config at an existing run directory when resuming training.

    If ``cfg.TRAIN.RESUME`` is set, it must be an existing folder. In that case
    ``cfg.TRAIN.PRETRAINED`` is set to ``<resume>/checkpoints/last.ckpt``, and
    the wandb run id is recovered from the ``run-<id>.wandb`` file in
    ``<resume>/wandb/latest-run``.

    Raises:
        ValueError: If the resume path does not exist.
        FileNotFoundError: If no wandb run file can be found.
    """
    if cfg.TRAIN.RESUME:
        resume = cfg.TRAIN.RESUME
        if os.path.exists(resume):
            # Checkpoints
            cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt")
            # Wandb
            wandb_dir = pjoin(resume, "wandb", "latest-run")
            wandb_files = sorted(os.listdir(wandb_dir))
            # Match only "run-<id>.wandb" so sidecar files (e.g. "*.wandb.synced")
            # cannot leak into the recovered id.
            wandb_runs = [
                item for item in wandb_files
                if item.startswith("run-") and item.endswith(".wandb")
            ]
            if not wandb_runs:
                raise FileNotFoundError(
                    f"No wandb run file (run-*.wandb) found in {wandb_dir}.")
            wandb_run = wandb_runs[0]
            cfg.LOGGER.WANDB.params.id = wandb_run[len("run-"):-len(".wandb")]
        else:
            raise ValueError("Resume path is not right.")
    return cfg


def parse_args(phase="train"):
    """
    Parse command-line arguments and build the merged OmegaConf config.

    Config sources are merged in this order, with later sources winning:
    ``default.yaml`` from ``CONFIG_FOLDER``; the experiment file given by
    ``--cfg``; the per-module yaml files, unless ``FULL_CONFIG`` is set; the
    assets file given by ``--cfg_assets``; and finally command-line overrides.
    Debug-mode tweaks and resume handling are applied last.

    Args:
        phase: One of "train", "test", "demo", "render" or "webui". It selects
            the default ``--cfg`` file and which extra arguments are accepted.

    Returns:
        The merged config.
    """
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")

    # Assets
    group.add_argument(
        "--cfg_assets",
        type=str,
        required=False,
        default="./configs/assets.yaml",
        help="config file for asset paths",
    )

    # Default config. Phases without a dedicated file (e.g. "demo") fall back
    # to default.yaml instead of leaving the default unbound.
    default_cfgs = {
        "train": "./configs/default.yaml",
        "test": "./configs/default.yaml",
        "render": "./configs/render.yaml",
        "webui": "./configs/webui.yaml",
    }
    cfg_default = default_cfgs.get(phase, "./configs/default.yaml")

    group.add_argument(
        "--cfg",
        type=str,
        required=False,
        default=cfg_default,
        help="config file",
    )

    # Parse for each phase
    if phase in ["train", "test"]:
        group.add_argument("--batch_size", type=int, required=False, help="training batch size")
        group.add_argument("--num_nodes", type=int, required=False, help="number of nodes")
        group.add_argument("--device", type=int, nargs="+", required=False, help="training device")
        group.add_argument("--task", type=str, required=False, help="evaluation task type")
        group.add_argument("--nodebug", action="store_true", required=False, help="debug or not")

    if phase == "demo":
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input text and lengths with txt format",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument("--task", type=str, required=False, help="evaluation task type")
        # These two are read below; they were previously never declared,
        # which made the demo phase crash with AttributeError.
        group.add_argument("--render", action="store_true", required=False, help="render the outputs")
        group.add_argument("--frame_rate", type=float, required=False, default=None, help="output frame rate")

    if phase == "render":
        group.add_argument("--npy", type=str, required=False, default=None, help="npy motion files")
        group.add_argument("--dir", type=str, required=False, default=None, help="npy motion folder")
        group.add_argument("--fps", type=int, required=False, default=30, help="render fps")
        group.add_argument(
            "--mode",
            type=str,
            required=False,
            default="sequence",
            help="render target: video, sequence, frame",
        )

    params = parser.parse_args()

    # Load yaml config files.
    # Register the resolver only once: register_new_resolver raises if the
    # name already exists, so a second parse_args() call would fail.
    # NOTE(review): the "eval" resolver executes arbitrary Python from config
    # files; only load trusted yaml.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    if not cfg_exp.FULL_CONFIG:
        cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
    cfg = OmegaConf.merge(cfg_exp, cfg_assets)

    # Update config with arguments
    if phase in ["train", "test"]:
        cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE
        cfg.DEVICE = params.device if params.device else cfg.DEVICE
        cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES
        cfg.model.params.task = params.task if params.task else cfg.model.params.task
        # store_true is never None, so debug is on unless --nodebug is given.
        cfg.DEBUG = not params.nodebug

        # Force no debug in test
        if phase == "test":
            cfg.DEBUG = False
            cfg.DEVICE = [0]
            print("Force no debugging and one gpu when testing")

    if phase == "demo":
        cfg.DEMO.RENDER = params.render
        # Keep the configured frame rate unless one is given on the CLI.
        if params.frame_rate is not None:
            cfg.DEMO.FRAME_RATE = params.frame_rate
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.TASK = params.task
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
        os.makedirs(cfg.TEST.FOLDER, exist_ok=True)

    if phase == "render":
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        if params.fps:
            cfg.RENDER.FPS = float(params.fps)
        cfg.RENDER.MODE = params.mode

    # Debug mode
    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.params.offline = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1

    # Resume config
    cfg = resume_config(cfg)

    return cfg