|
import importlib
import os
from argparse import ArgumentParser
from typing import Any, Optional, Sequence, Type, TypeVar

from omegaconf import OmegaConf, DictConfig
|
|
|
|
|
def get_module_config(cfg_model: DictConfig, path: str = "modules",
                      root: str = "./configs") -> DictConfig:
    """Merge every ``*.yaml`` file in ``<root>/<path>/`` into ``cfg_model``.

    Files are merged in sorted filename order. ``os.listdir`` order is
    arbitrary, so without sorting the winner for a key defined in several
    files would differ between platforms and filesystems. Entries whose name
    ends in ``.yaml`` but are not regular files (e.g. directories) are skipped.

    Args:
        cfg_model: Config to merge into. It is mutated in place through
            ``merge_with``.
        path: Sub-directory of ``root`` that holds the module YAML files.
        root: Base configuration directory. The default ``./configs`` is
            resolved against the current working directory.

    Returns:
        The same ``cfg_model`` object after all merges.

    Raises:
        FileNotFoundError: If ``<root>/<path>`` does not exist.
    """
    config_dir = os.path.join(root, path)
    for file_name in sorted(os.listdir(config_dir)):
        file_path = os.path.join(config_dir, file_name)
        if not file_name.endswith('.yaml') or not os.path.isfile(file_path):
            continue
        # Explicit encoding: the platform default may not be UTF-8 (e.g. Windows).
        with open(file_path, 'r', encoding='utf-8') as f:
            cfg_model.merge_with(OmegaConf.load(f))
    return cfg_model
|
|
|
|
|
def get_obj_from_str(string: str, reload: bool = False) -> Type:
    """Resolve a dotted ``"package.module.Name"`` string to the named object.

    Everything before the last dot is imported as a module, and the final
    component is looked up on that module as an attribute.

    Args:
        string: Fully qualified dotted path, e.g. ``"collections.OrderedDict"``.
        reload: If true, re-execute the module with ``importlib.reload``
            before the attribute lookup.

    Returns:
        The attribute named by the last path component (typically a class).

    Raises:
        ValueError: If ``string`` contains no dot.
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, attr_name = string.rsplit(".", 1)
    target_module = importlib.import_module(module_name)
    if reload:
        target_module = importlib.reload(target_module)
    return getattr(target_module, attr_name)
|
|
|
|
|
def instantiate_from_config(config: DictConfig) -> Any:
    """Build an object described by a ``{"target": ..., "params": {...}}`` config.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``. The resolved callable is invoked with the entries of
    ``config["params"]`` as keyword arguments. A missing ``params`` key, or one
    whose value is null, means "call with no arguments". A null value arises,
    for example, from an empty ``params:`` line in YAML, which would otherwise
    fail with ``**None``.

    Args:
        config: Mapping with a required ``target`` key and an optional
            ``params`` mapping.

    Returns:
        Whatever calling the resolved target returns.

    Raises:
        KeyError: If ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params") or {}
    return get_obj_from_str(config["target"])(**params)
|
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> DictConfig:
    """Parse command-line options and build the run configuration.

    The function performs three steps:

    1. Load the YAML file given by ``--cfg``.
    2. Merge in the module configs found under
       ``./configs/<cfg.model.target>/`` using ``get_module_config``.
    3. Copy the remaining CLI options onto the result as ``example``,
       ``no_plot`` and ``replication``.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, matching the previous
            behaviour.

    Returns:
        The merged configuration.
    """
    parser = ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True, help="config file")
    parser.add_argument('--example', type=str, required=False,
                        help="input text and lengths with txt format")
    parser.add_argument('--no-plot', action="store_true", required=False,
                        help="disable plotting of the skeleton-based motion")
    parser.add_argument('--replication', type=int, default=1,
                        help="the number of replication of sampling")
    args = parser.parse_args(argv)

    cfg = OmegaConf.load(args.cfg)
    # get_module_config merges the module YAMLs into cfg.model in place.
    # Merging the result into cfg then also copies those keys to the top
    # level. NOTE(review): presumably intentional so module keys are reachable
    # from the root; confirm.
    cfg_model = get_module_config(cfg.model, cfg.model.target)
    cfg = OmegaConf.merge(cfg, cfg_model)

    cfg.example = args.example
    cfg.no_plot = args.no_plot
    cfg.replication = args.replication
    return cfg
|
|