XzJosh commited on
Commit
a89409b
1 Parent(s): e33d8a8

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +10 -6
utils.py CHANGED
@@ -11,7 +11,8 @@ import torch
11
 
12
  MATPLOTLIB_FLAG = False
13
 
14
- logger = logging.getLogger(__name__)
 
15
 
16
 
17
  def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
@@ -22,7 +23,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
22
  if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
23
  optimizer.load_state_dict(checkpoint_dict['optimizer'])
24
  elif optimizer is None and not skip_optimizer:
25
- #else: Disable this line if Infer and resume checkpoint,then enable the line upper
26
  new_opt_dict = optimizer.state_dict()
27
  new_opt_dict_params = new_opt_dict['param_groups'][0]['params']
28
  new_opt_dict['param_groups'] = checkpoint_dict['optimizer']['param_groups']
@@ -41,12 +42,13 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
41
  new_state_dict[k] = saved_state_dict[k]
42
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
43
  except:
44
- logger.error("%s is not in the checkpoint" % k)
45
  new_state_dict[k] = v
46
  if hasattr(model, 'module'):
47
  model.module.load_state_dict(new_state_dict, strict=False)
48
  else:
49
  model.load_state_dict(new_state_dict, strict=False)
 
50
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
51
  checkpoint_path, iteration))
52
  return model, optimizer, learning_rate, iteration
@@ -154,8 +156,9 @@ def get_hparams(init=True):
154
  parser = argparse.ArgumentParser()
155
  parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
156
  help='JSON file for configuration')
157
- parser.add_argument('-m', '--model', type=str, required=True,
158
  help='Model name')
 
159
 
160
  args = parser.parse_args()
161
  model_dir = os.path.join("./logs", args.model)
@@ -177,6 +180,7 @@ def get_hparams(init=True):
177
 
178
  hparams = HParams(**config)
179
  hparams.model_dir = model_dir
 
180
  return hparams
181
 
182
 
@@ -204,7 +208,7 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim
204
 
205
  def get_hparams_from_dir(model_dir):
206
  config_save_path = os.path.join(model_dir, "config.json")
207
- with open(config_save_path, "r", encoding='utf-8') as f:
208
  data = f.read()
209
  config = json.loads(data)
210
 
@@ -214,7 +218,7 @@ def get_hparams_from_dir(model_dir):
214
 
215
 
216
  def get_hparams_from_file(config_path):
217
- with open(config_path, "r", encoding='utf-8') as f:
218
  data = f.read()
219
  config = json.loads(data)
220
 
 
11
 
12
  MATPLOTLIB_FLAG = False
13
 
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
 
17
 
18
  def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
 
23
  if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
24
  optimizer.load_state_dict(checkpoint_dict['optimizer'])
25
  elif optimizer is None and not skip_optimizer:
26
+ #else: #Disable this line if Infer ,and enable the line upper
27
  new_opt_dict = optimizer.state_dict()
28
  new_opt_dict_params = new_opt_dict['param_groups'][0]['params']
29
  new_opt_dict['param_groups'] = checkpoint_dict['optimizer']['param_groups']
 
42
  new_state_dict[k] = saved_state_dict[k]
43
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
44
  except:
45
+ print("error, %s is not in the checkpoint" % k)
46
  new_state_dict[k] = v
47
  if hasattr(model, 'module'):
48
  model.module.load_state_dict(new_state_dict, strict=False)
49
  else:
50
  model.load_state_dict(new_state_dict, strict=False)
51
+ print("load ")
52
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
53
  checkpoint_path, iteration))
54
  return model, optimizer, learning_rate, iteration
 
156
  parser = argparse.ArgumentParser()
157
  parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
158
  help='JSON file for configuration')
159
+ parser.add_argument('-m', '--model', type=str, default="./OUTPUT_MODEL",
160
  help='Model name')
161
+ parser.add_argument('--cont', dest='cont', action="store_true", default=False, help="whether to continue training on the latest checkpoint")
162
 
163
  args = parser.parse_args()
164
  model_dir = os.path.join("./logs", args.model)
 
180
 
181
  hparams = HParams(**config)
182
  hparams.model_dir = model_dir
183
+ hparams.cont = args.cont
184
  return hparams
185
 
186
 
 
208
 
209
  def get_hparams_from_dir(model_dir):
210
  config_save_path = os.path.join(model_dir, "config.json")
211
+ with open(config_save_path, "r") as f:
212
  data = f.read()
213
  config = json.loads(data)
214
 
 
218
 
219
 
220
  def get_hparams_from_file(config_path):
221
+ with open(config_path, "r") as f:
222
  data = f.read()
223
  config = json.loads(data)
224