JerryLiJinyi's picture
Upload 127 files
10b912d verified
raw
history blame
1.74 kB
import json
from dataclasses import dataclass, make_dataclass, asdict, field
from typing import List
@dataclass
class Config:
    """Default values for every setting; each field can be overridden."""

    # --- file system locations ---
    config: str = "config/default.json"
    loader: str = "loaders/newsroom.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own fresh list
    validation_datasets: List = field(default_factory=list)

    # --- training settings / hyperparameters ---
    batch_size: int = 4
    learning_rate: float = 1e-05
    k_samples: int = 1
    # only "max" and "mean" pass validate_config
    sample_aggregation: str = "max"
    # The three limits below default to None
    # (presumably "no limit" -- TODO confirm against the training loop).
    max_val_steps: int = None
    max_train_steps: int = None
    max_train_seconds: int = None
    print_every: int = 10
    save_every: int = 100
    eval_every: int = 100
    verbose: bool = True

    # --- pretrained models ---
    encoder_model_id: str = "distilroberta-base"

    # --- reward settings ---
    rewards: tuple = ("FluencyReward", "BiEncoderSimilarity", "GaussianLength")
def validate_config(args):
    """
    Checks that the loaded settings hold supported values.

    Currently only `args.sample_aggregation` is checked; it must be
    either "max" or "mean".

    Raises:
        AssertionError: if `args.sample_aggregation` is anything else.
    """
    allowed = ("max", "mean")
    if args.sample_aggregation not in allowed:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            f"sample_aggregation must be one of {allowed}, "
            f"got {args.sample_aggregation!r}"
        )
def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined above by Config
    - args.config (path to a JSON config file)
    - args (from using argparse in a script)
    Overlapping fields are overwritten in that order. Note that every
    attribute of args is applied, including attributes whose value is None.

    The returned object is an instance of a dataclass built on the fly, so
    it also carries any extra keys found in the JSON file or in args.

    Example usage:
    (...)
    args = load_config(parser.parse_args())
    args.batch_size
    """
    config = asdict(Config())
    if args.config:
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(args.config, encoding="utf-8") as f:
            config.update(json.load(f))
    config.update(vars(args))
    # make_dataclass expects (name, type) pairs. Passing the raw values as
    # types would make dataclasses treat string values as string annotations.
    field_specs = [(name, type(value)) for name, value in config.items()]
    Config_ = make_dataclass("Config", fields=field_specs)
    config_object = Config_(**config)
    validate_config(config_object)
    return config_object