|
from dataclasses import dataclass |
|
from typing import List |
|
import yaml |
|
|
|
from ...util.hparams import HyperParams |
|
|
|
|
|
@dataclass
class ROMEHyperParams(HyperParams):
    """Hyperparameters for the ROME editing algorithm.

    Instances are normally created with :meth:`from_hparams`, which loads a
    YAML file whose top-level keys match the field names below. Every field
    without a default must be present in that file.
    """

    # --- Method parameters ---------------------------------------------
    # NOTE(review): the meaning of these values is defined by the ROME code
    # that consumes this object, not here; confirm there before relying on
    # any interpretation of an individual field.
    layers: List[int]
    fact_token: str
    v_num_grad_steps: int
    v_lr: float
    v_loss_layer: int
    v_weight_decay: float
    clamp_norm_factor: float
    kl_factor: float
    mom2_adjustment: bool
    context_template_length_params: List[List[int]]

    # --- Module name templates -----------------------------------------
    # Presumably name templates used to locate submodules of the model
    # being edited — TODO confirm the expected format with the consumer.
    rewrite_module_tmp: str
    layer_module_tmp: str
    mlp_module_tmp: str
    attn_module_tmp: str
    ln_f_module: str
    lm_head_module: str

    # --- Statistics and run settings -----------------------------------
    mom2_dataset: str
    mom2_n_samples: int
    mom2_dtype: str
    alg_name: str  # from_hparams rejects any value other than 'ROME'
    device: int
    model_name: str
    stats_dir: str

    # --- Optional settings (have defaults, may be omitted in YAML) -----
    max_length: int = 40
    model_parallel: bool = False
    fp16: bool = False

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str) -> "ROMEHyperParams":
        """Load ROME hyperparameters from a YAML file.

        Args:
            hparams_name_or_path: Path to a YAML file. If it does not already
                end with ``.yaml``, that suffix is appended before opening.

        Returns:
            An instance built from the file's top-level mapping, after the
            base class's ``construct_float_from_scientific_notation`` pass.

        Raises:
            AssertionError: If the file does not contain a top-level mapping
                (e.g. it is empty), or its ``alg_name`` is not ``'ROME'``.
                Raised explicitly rather than via ``assert`` so the check is
                not stripped when Python runs with ``-O``.
            TypeError: If the mapping has keys that are not fields of this
                class or omits a required field (raised by the dataclass
                constructor).
        """
        # Suffix check, not substring: a substring test wrongly treats
        # paths such as 'configs.yaml.d/ROME' or 'rome.yaml.bak' as complete.
        if not hparams_name_or_path.endswith('.yaml'):
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        with open(hparams_name_or_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)

        # safe_load returns None for an empty document; anything that is not
        # a mapping cannot be expanded into constructor keyword arguments.
        if not isinstance(config, dict):
            raise AssertionError(
                f'ROMEHyperParams can not load from {hparams_name_or_path}, '
                f'expected a YAML mapping but got {type(config).__name__}'
            )

        config = super().construct_float_from_scientific_notation(config)

        # Use .get so a missing key produces the same descriptive error
        # instead of a bare KeyError.
        alg_name = config.get('alg_name')
        if alg_name != 'ROME':
            raise AssertionError(
                f'ROMEHyperParams can not load from {hparams_name_or_path}, '
                f'alg_name is {alg_name} '
            )

        return cls(**config)
|
|