import os
from dataclasses import dataclass, field
from datetime import datetime

from omegaconf import DictConfig, OmegaConf

import lrm
from .typing import *

# ============ Register OmegaConf Resolvers ============= #
OmegaConf.register_new_resolver(
"calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
)
OmegaConf.register_new_resolver("add", lambda a, b: a + b)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("div", lambda a, b: a / b)
OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
OmegaConf.register_new_resolver("not", lambda s: not s)
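
# For illustration, these resolvers are used as interpolations inside YAML
# configs; a hypothetical snippet (the keys are made up for this example):
#   batch_size: 8
#   accumulate: 4
#   effective_batch_size: "${mul:${batch_size},${accumulate}}"   # -> 32
#   lr_decay: "${calc_exp_lr_decay_rate:0.1,1000}"               # -> 0.1 ** (1 / 1000)
#   scale: "${tuple2:0.5}"                                       # -> [0.5, 0.5]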

def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    # steps per epoch under data parallelism over num_nodes * num_cards GPUs,
    # multiplied by the number of epochs
    return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs


OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)
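# Worked example (hypothetical numbers): num_data=100000, batch_size=4,
# num_nodes=2, num_cards=8 gives int(100000 / (2 * 8 * 4)) = 1562 steps per
# epoch, so max_epochs=10 yields 15620 total training steps.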
# ======================================================= #
# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
# TODO
name = f"lrm_{cfg.system.backbone.num_layers}"
return name
# ======================================================= #


@dataclass
class ExperimentConfig:
name: str = "default"
description: str = ""
tag: str = ""
seed: int = 0
use_timestamp: bool = True
timestamp: Optional[str] = None
exp_root_dir: str = "outputs"
### these shouldn't be set manually
exp_dir: str = "outputs/default"
trial_name: str = "exp"
trial_dir: str = "outputs/default/exp"
n_gpus: int = 1
###
resume: Optional[str] = None
data_cls: str = ""
data: dict = field(default_factory=dict)
system_cls: str = ""
system: dict = field(default_factory=dict)
# accept pytorch-lightning trainer parameters
# see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
trainer: dict = field(default_factory=dict)
# accept pytorch-lightning checkpoint callback parameters
# see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
checkpoint: dict = field(default_factory=dict)
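
# A config consumed by load_config below might look like the following
# (a hypothetical YAML sketch; the class paths and keys are illustrative only):
#   name: auto
#   tag: debug
#   data_cls: lrm.data.DataModule          # hypothetical class path
#   data:
#     batch_size: 4
#   system_cls: lrm.systems.LRMSystem      # hypothetical class path
#   system:
#     backbone:
#       num_layers: 12
#   trainer:
#     max_steps: "${calc_num_train_steps:100000,${data.batch_size},10,1}"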


def load_config(
    *yamls: str, cli_args: Optional[list] = None, from_string=False, makedirs=True, **kwargs
) -> Any:
if from_string:
parse_func = OmegaConf.create
else:
parse_func = OmegaConf.load
yaml_confs = []
for y in yamls:
conf = parse_func(y)
extends = conf.pop("extends", None)
if extends:
assert os.path.exists(extends), f"File {extends} does not exist."
yaml_confs.append(OmegaConf.load(extends))
yaml_confs.append(conf)
    # guard against None so OmegaConf.from_cli does not fall back to sys.argv
    cli_conf = OmegaConf.from_cli(cli_args or [])
cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
OmegaConf.resolve(cfg)
assert isinstance(cfg, DictConfig)
scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)
# post processing
# auto naming
if scfg.name == "auto":
scfg.name = get_naming_convention(scfg)
# add timestamp
    if not scfg.tag and not scfg.use_timestamp:
        raise ValueError("Either a tag must be specified or use_timestamp must be True.")
scfg.trial_name = scfg.tag
    # when resuming from an existing config, scfg.timestamp should not be None
if scfg.timestamp is None:
scfg.timestamp = ""
if scfg.use_timestamp:
if scfg.n_gpus > 1:
            lrm.warn(
                "Timestamp is disabled when using multiple GPUs; please make sure you have a unique tag."
            )
else:
scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
# make directories
scfg.trial_name += scfg.timestamp
scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name)
scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name)
if makedirs:
os.makedirs(scfg.trial_dir, exist_ok=True)
return scfg
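
# A minimal usage sketch (the file name and overrides are hypothetical):
#   cfg = load_config("configs/base.yaml", cli_args=["tag=debug", "seed=42"])
#   print(cfg.trial_dir)  # e.g. outputs/<name>/debug@20240101-120000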


def config_to_primitive(config, resolve: bool = True) -> Any:
    return OmegaConf.to_container(config, resolve=resolve)


def dump_config(path: str, config) -> None:
    with open(path, "w") as fp:
        OmegaConf.save(config=config, f=fp)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg
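

# Round-trip sketch (the file name is hypothetical): dump the resolved config
# into the trial directory, then reload it later for resuming:
#   dump_config(os.path.join(cfg.trial_dir, "parsed.yaml"), cfg)
#   cfg = load_config(os.path.join(cfg.trial_dir, "parsed.yaml"))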