Rongjiehuang commited on
Commit
98a2c89
1 Parent(s): 6ceff9a
Files changed (1) hide show
  1. utils/hparams.py +35 -38
utils/hparams.py CHANGED
@@ -1,7 +1,5 @@
1
  import argparse
2
  import os
3
- import subprocess
4
-
5
  import yaml
6
 
7
  global_print_hparams = True
@@ -23,31 +21,30 @@ def override_config(old_config: dict, new_config: dict):
23
 
24
 
25
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
26
- if config == '' and exp_name == '':
27
- parser = argparse.ArgumentParser(description='')
28
- parser.add_argument('--config', type=str, default='configs/config_base.yaml',
29
  help='location of the data corpus')
30
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
31
  parser.add_argument('--hparams', type=str, default='',
32
  help='location of the data corpus')
33
- parser.add_argument('--infer', action='store_true', help='infer')
34
  parser.add_argument('--validate', action='store_true', help='validate')
35
  parser.add_argument('--reset', action='store_true', help='reset hparams')
36
- parser.add_argument('--remove', action='store_true', help='remove old ckpt')
37
  parser.add_argument('--debug', action='store_true', help='debug')
38
  args, unknown = parser.parse_known_args()
39
  else:
40
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
41
  infer=False, validate=False, reset=False, debug=False)
42
- global hparams
43
- assert args.config != '' or args.exp_name != ''
 
 
44
 
45
  config_chains = []
46
  loaded_config = set()
47
 
48
  def load_config(config_fn): # deep first
49
- if not os.path.exists(config_fn):
50
- return {}
51
  with open(config_fn) as f:
52
  hparams_ = yaml.safe_load(f)
53
  loaded_config.add(config_fn)
@@ -56,10 +53,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
56
  if not isinstance(hparams_['base_config'], list):
57
  hparams_['base_config'] = [hparams_['base_config']]
58
  for c in hparams_['base_config']:
59
- if c.startswith('.'):
60
- c = f'{os.path.dirname(config_fn)}/{c}'
61
- c = os.path.normpath(c)
62
  if c not in loaded_config:
 
 
 
63
  override_config(ret_hparams, load_config(c))
64
  override_config(ret_hparams, hparams_)
65
  else:
@@ -67,53 +64,48 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
67
  config_chains.append(config_fn)
68
  return ret_hparams
69
 
 
 
70
  saved_hparams = {}
71
- args_work_dir = ''
72
- if args.exp_name != '':
73
- args_work_dir = f'checkpoints/{args.exp_name}'
74
  ckpt_config_path = f'{args_work_dir}/config.yaml'
75
  if os.path.exists(ckpt_config_path):
76
- with open(ckpt_config_path) as f:
77
- saved_hparams.update(yaml.safe_load(f))
 
 
 
 
 
 
78
  hparams_ = {}
79
- if args.config != '':
80
- hparams_.update(load_config(args.config))
81
  if not args.reset:
82
  hparams_.update(saved_hparams)
83
  hparams_['work_dir'] = args_work_dir
84
 
85
- # --hparams="a=1,b.c=2,d=[1 1 1]"
86
  if args.hparams != "":
87
  for new_hparam in args.hparams.split(","):
88
  k, v = new_hparam.split("=")
89
- v = v.strip("\'\" ")
90
- config_node = hparams_
91
- for k_ in k.split(".")[:-1]:
92
- config_node = config_node[k_]
93
- k = k.split(".")[-1]
94
- if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
95
- if type(config_node[k]) == list:
96
- v = v.replace(" ", ",")
97
- config_node[k] = eval(v)
98
  else:
99
- config_node[k] = type(config_node[k])(v)
100
- if args_work_dir != '' and args.remove:
101
- answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
102
- if answer.lower() == "y":
103
- subprocess.check_call(f'rm -rf {args_work_dir}', shell=True)
104
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
105
  os.makedirs(hparams_['work_dir'], exist_ok=True)
106
  with open(ckpt_config_path, 'w') as f:
107
  yaml.safe_dump(hparams_, f)
108
 
109
- hparams_['infer'] = args.infer
110
  hparams_['debug'] = args.debug
111
  hparams_['validate'] = args.validate
112
- hparams_['exp_name'] = args.exp_name
113
  global global_print_hparams
114
  if global_hparams:
115
  hparams.clear()
116
  hparams.update(hparams_)
 
117
  if print_hparams and global_print_hparams and global_hparams:
118
  print('| Hparams chains: ', config_chains)
119
  print('| Hparams: ')
@@ -121,4 +113,9 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
121
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
122
  print("")
123
  global_print_hparams = False
 
 
 
 
 
124
  return hparams_
 
1
  import argparse
2
  import os
 
 
3
  import yaml
4
 
5
  global_print_hparams = True
 
21
 
22
 
23
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
24
+ if config == '':
25
+ parser = argparse.ArgumentParser(description='neural music')
26
+ parser.add_argument('--config', type=str, default='',
27
  help='location of the data corpus')
28
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
29
  parser.add_argument('--hparams', type=str, default='',
30
  help='location of the data corpus')
31
+ parser.add_argument('--inference', action='store_true', help='inference')
32
  parser.add_argument('--validate', action='store_true', help='validate')
33
  parser.add_argument('--reset', action='store_true', help='reset hparams')
 
34
  parser.add_argument('--debug', action='store_true', help='debug')
35
  args, unknown = parser.parse_known_args()
36
  else:
37
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
38
  infer=False, validate=False, reset=False, debug=False)
39
+ args_work_dir = ''
40
+ if args.exp_name != '':
41
+ args.work_dir = args.exp_name
42
+ args_work_dir = f'checkpoints/{args.work_dir}'
43
 
44
  config_chains = []
45
  loaded_config = set()
46
 
47
  def load_config(config_fn): # deep first
 
 
48
  with open(config_fn) as f:
49
  hparams_ = yaml.safe_load(f)
50
  loaded_config.add(config_fn)
 
53
  if not isinstance(hparams_['base_config'], list):
54
  hparams_['base_config'] = [hparams_['base_config']]
55
  for c in hparams_['base_config']:
 
 
 
56
  if c not in loaded_config:
57
+ if c.startswith('.'):
58
+ c = f'{os.path.dirname(config_fn)}/{c}'
59
+ c = os.path.normpath(c)
60
  override_config(ret_hparams, load_config(c))
61
  override_config(ret_hparams, hparams_)
62
  else:
 
64
  config_chains.append(config_fn)
65
  return ret_hparams
66
 
67
+ global hparams
68
+ assert args.config != '' or args_work_dir != ''
69
  saved_hparams = {}
70
+ if args_work_dir != 'checkpoints/':
 
 
71
  ckpt_config_path = f'{args_work_dir}/config.yaml'
72
  if os.path.exists(ckpt_config_path):
73
+ try:
74
+ with open(ckpt_config_path) as f:
75
+ saved_hparams.update(yaml.safe_load(f))
76
+ except:
77
+ pass
78
+ if args.config == '':
79
+ args.config = ckpt_config_path
80
+
81
  hparams_ = {}
82
+ hparams_.update(load_config(args.config))
83
+
84
  if not args.reset:
85
  hparams_.update(saved_hparams)
86
  hparams_['work_dir'] = args_work_dir
87
 
 
88
  if args.hparams != "":
89
  for new_hparam in args.hparams.split(","):
90
  k, v = new_hparam.split("=")
91
+ if v in ['True', 'False'] or type(hparams_[k]) == bool:
92
+ hparams_[k] = eval(v)
 
 
 
 
 
 
 
93
  else:
94
+ hparams_[k] = type(hparams_[k])(v)
95
+
 
 
 
96
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
97
  os.makedirs(hparams_['work_dir'], exist_ok=True)
98
  with open(ckpt_config_path, 'w') as f:
99
  yaml.safe_dump(hparams_, f)
100
 
101
+ hparams_['inference'] = args.infer
102
  hparams_['debug'] = args.debug
103
  hparams_['validate'] = args.validate
 
104
  global global_print_hparams
105
  if global_hparams:
106
  hparams.clear()
107
  hparams.update(hparams_)
108
+
109
  if print_hparams and global_print_hparams and global_hparams:
110
  print('| Hparams chains: ', config_chains)
111
  print('| Hparams: ')
 
113
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
114
  print("")
115
  global_print_hparams = False
116
+ # print(hparams_.keys())
117
+ if hparams.get('exp_name') is None:
118
+ hparams['exp_name'] = args.exp_name
119
+ if hparams_.get('exp_name') is None:
120
+ hparams_['exp_name'] = args.exp_name
121
  return hparams_