fffiloni commited on
Commit
b8ea8bb
1 Parent(s): 80627a9

Create argument.py

Browse files
Files changed (1) hide show
  1. utils/argument.py +98 -0
utils/argument.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import json
3
+ import argparse
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ def load_config_dict_to_opt(opt, config_dict):
10
+ """
11
+ Load the key, value pairs from config_dict to opt, overriding existing values in opt
12
+ if there is any.
13
+ """
14
+ if not isinstance(config_dict, dict):
15
+ raise TypeError("Config must be a Python dictionary")
16
+ for k, v in config_dict.items():
17
+ k_parts = k.split('.')
18
+ pointer = opt
19
+ for k_part in k_parts[:-1]:
20
+ if k_part not in pointer:
21
+ pointer[k_part] = {}
22
+ pointer = pointer[k_part]
23
+ assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
24
+ ori_value = pointer.get(k_parts[-1])
25
+ pointer[k_parts[-1]] = v
26
+ if ori_value:
27
+ logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
28
+
29
+
30
+ def load_opt_from_config_files(conf_file):
31
+ """
32
+ Load opt from the config files, settings in later files can override those in previous files.
33
+
34
+ Args:
35
+ conf_files: config file path
36
+
37
+ Returns:
38
+ dict: a dictionary of opt settings
39
+ """
40
+ opt = {}
41
+ with open(conf_file, encoding='utf-8') as f:
42
+ config_dict = yaml.safe_load(f)
43
+
44
+ load_config_dict_to_opt(opt, config_dict)
45
+
46
+ return opt
47
+
48
+
49
+ def load_opt_command(args):
50
+ parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
51
+ parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
52
+ parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).')
53
+ parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
54
+ parser.add_argument('--overrides', help='arguments that used to overide the config file in cmdline', nargs=argparse.REMAINDER)
55
+
56
+ cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
57
+
58
+ opt = load_opt_from_config_files(cmdline_args.conf_files)
59
+
60
+ if cmdline_args.config_overrides:
61
+ config_overrides_string = ' '.join(cmdline_args.config_overrides)
62
+ logger.warning(f"Command line config overrides: {config_overrides_string}")
63
+ config_dict = json.loads(config_overrides_string)
64
+ load_config_dict_to_opt(opt, config_dict)
65
+
66
+ if cmdline_args.overrides:
67
+ assert len(cmdline_args.overrides) % 2 == 0, "overides arguments is not paired, required: key value"
68
+ keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
69
+ vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
70
+ vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]
71
+
72
+ types = []
73
+ for key in keys:
74
+ key = key.split('.')
75
+ ele = opt.copy()
76
+ while len(key) > 0:
77
+ ele = ele[key.pop(0)]
78
+ types.append(type(ele))
79
+
80
+ config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}
81
+ load_config_dict_to_opt(opt, config_dict)
82
+
83
+ # combine cmdline_args into opt dictionary
84
+ for key, val in cmdline_args.__dict__.items():
85
+ if val is not None:
86
+ opt[key] = val
87
+
88
+ return opt, cmdline_args
89
+
90
+
91
+ def save_opt_to_json(opt, conf_file):
92
+ with open(conf_file, 'w', encoding='utf-8') as f:
93
+ json.dump(opt, f, indent=4)
94
+
95
+
96
+ def save_opt_to_yaml(opt, conf_file):
97
+ with open(conf_file, 'w', encoding='utf-8') as f:
98
+ yaml.dump(opt, f)