|
from dataclasses import dataclass |
|
from typing import List |
|
from ...util.hparams import HyperParams |
|
import yaml |
|
|
|
|
|
@dataclass |
|
class GraceHyperParams(HyperParams): |
|
|
|
|
|
edit_lr: int |
|
n_iter: int |
|
|
|
eps: float |
|
dist_fn: str |
|
val_init: str |
|
val_train: str |
|
val_reg: str |
|
reg: str |
|
replacement: str |
|
eps_expand: str |
|
num_pert: str |
|
dropout: float |
|
|
|
|
|
inner_params: List[str] |
|
device: int |
|
alg_name: str |
|
model_name: str |
|
|
|
|
|
batch_size: int = 128 |
|
max_length: int = 30 |
|
model_parallel: bool = False |
|
|
|
@classmethod |
|
def from_hparams(cls, hparams_name_or_path: str): |
|
if '.yaml' not in hparams_name_or_path: |
|
hparams_name_or_path = hparams_name_or_path + '.yaml' |
|
|
|
with open(hparams_name_or_path, "r") as stream: |
|
config = yaml.safe_load(stream) |
|
config = super().construct_float_from_scientific_notation(config) |
|
|
|
assert (config and config['alg_name'] == 'GRACE') or print( |
|
f'GraceHyperParams can not load from {hparams_name_or_path}, ' |
|
f'alg_name is {config["alg_name"]} ') |
|
return cls(**config) |
|
|