Spaces:
Paused
Paused
Create argument.py
Browse files- utils/argument.py +98 -0
utils/argument.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
def load_config_dict_to_opt(opt, config_dict):
|
10 |
+
"""
|
11 |
+
Load the key, value pairs from config_dict to opt, overriding existing values in opt
|
12 |
+
if there is any.
|
13 |
+
"""
|
14 |
+
if not isinstance(config_dict, dict):
|
15 |
+
raise TypeError("Config must be a Python dictionary")
|
16 |
+
for k, v in config_dict.items():
|
17 |
+
k_parts = k.split('.')
|
18 |
+
pointer = opt
|
19 |
+
for k_part in k_parts[:-1]:
|
20 |
+
if k_part not in pointer:
|
21 |
+
pointer[k_part] = {}
|
22 |
+
pointer = pointer[k_part]
|
23 |
+
assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
|
24 |
+
ori_value = pointer.get(k_parts[-1])
|
25 |
+
pointer[k_parts[-1]] = v
|
26 |
+
if ori_value:
|
27 |
+
logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
|
28 |
+
|
29 |
+
|
30 |
+
def load_opt_from_config_files(conf_file):
|
31 |
+
"""
|
32 |
+
Load opt from the config files, settings in later files can override those in previous files.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
conf_files: config file path
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
dict: a dictionary of opt settings
|
39 |
+
"""
|
40 |
+
opt = {}
|
41 |
+
with open(conf_file, encoding='utf-8') as f:
|
42 |
+
config_dict = yaml.safe_load(f)
|
43 |
+
|
44 |
+
load_config_dict_to_opt(opt, config_dict)
|
45 |
+
|
46 |
+
return opt
|
47 |
+
|
48 |
+
|
49 |
+
def load_opt_command(args):
|
50 |
+
parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
|
51 |
+
parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
|
52 |
+
parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).')
|
53 |
+
parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
|
54 |
+
parser.add_argument('--overrides', help='arguments that used to overide the config file in cmdline', nargs=argparse.REMAINDER)
|
55 |
+
|
56 |
+
cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
|
57 |
+
|
58 |
+
opt = load_opt_from_config_files(cmdline_args.conf_files)
|
59 |
+
|
60 |
+
if cmdline_args.config_overrides:
|
61 |
+
config_overrides_string = ' '.join(cmdline_args.config_overrides)
|
62 |
+
logger.warning(f"Command line config overrides: {config_overrides_string}")
|
63 |
+
config_dict = json.loads(config_overrides_string)
|
64 |
+
load_config_dict_to_opt(opt, config_dict)
|
65 |
+
|
66 |
+
if cmdline_args.overrides:
|
67 |
+
assert len(cmdline_args.overrides) % 2 == 0, "overides arguments is not paired, required: key value"
|
68 |
+
keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
|
69 |
+
vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
|
70 |
+
vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]
|
71 |
+
|
72 |
+
types = []
|
73 |
+
for key in keys:
|
74 |
+
key = key.split('.')
|
75 |
+
ele = opt.copy()
|
76 |
+
while len(key) > 0:
|
77 |
+
ele = ele[key.pop(0)]
|
78 |
+
types.append(type(ele))
|
79 |
+
|
80 |
+
config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}
|
81 |
+
load_config_dict_to_opt(opt, config_dict)
|
82 |
+
|
83 |
+
# combine cmdline_args into opt dictionary
|
84 |
+
for key, val in cmdline_args.__dict__.items():
|
85 |
+
if val is not None:
|
86 |
+
opt[key] = val
|
87 |
+
|
88 |
+
return opt, cmdline_args
|
89 |
+
|
90 |
+
|
91 |
+
def save_opt_to_json(opt, conf_file):
|
92 |
+
with open(conf_file, 'w', encoding='utf-8') as f:
|
93 |
+
json.dump(opt, f, indent=4)
|
94 |
+
|
95 |
+
|
96 |
+
def save_opt_to_yaml(opt, conf_file):
|
97 |
+
with open(conf_file, 'w', encoding='utf-8') as f:
|
98 |
+
yaml.dump(opt, f)
|