from dataclasses import dataclass from typing import List from ...util.hparams import HyperParams import yaml @dataclass class GraceHyperParams(HyperParams): # Experiments edit_lr: int n_iter: int # Method eps: float dist_fn: str val_init: str val_train: str val_reg: str reg: str replacement: str eps_expand: str num_pert: str dropout: float # Module templates inner_params: List[str] device: int alg_name: str model_name: str # Defaults batch_size: int = 128 max_length: int = 30 model_parallel: bool = False @classmethod def from_hparams(cls, hparams_name_or_path: str): if '.yaml' not in hparams_name_or_path: hparams_name_or_path = hparams_name_or_path + '.yaml' with open(hparams_name_or_path, "r") as stream: config = yaml.safe_load(stream) config = super().construct_float_from_scientific_notation(config) assert (config and config['alg_name'] == 'GRACE') or print( f'GraceHyperParams can not load from {hparams_name_or_path}, ' f'alg_name is {config["alg_name"]} ') return cls(**config)