Spaces:
Running
Running
Upload utils.py
Browse files
utils.py
CHANGED
@@ -11,7 +11,8 @@ import torch
|
|
11 |
|
12 |
MATPLOTLIB_FLAG = False
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
|
17 |
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
@@ -22,7 +23,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|
22 |
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
23 |
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
24 |
elif optimizer is None and not skip_optimizer:
|
25 |
-
#else:
|
26 |
new_opt_dict = optimizer.state_dict()
|
27 |
new_opt_dict_params = new_opt_dict['param_groups'][0]['params']
|
28 |
new_opt_dict['param_groups'] = checkpoint_dict['optimizer']['param_groups']
|
@@ -41,12 +42,13 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|
41 |
new_state_dict[k] = saved_state_dict[k]
|
42 |
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
|
43 |
except:
|
44 |
-
|
45 |
new_state_dict[k] = v
|
46 |
if hasattr(model, 'module'):
|
47 |
model.module.load_state_dict(new_state_dict, strict=False)
|
48 |
else:
|
49 |
model.load_state_dict(new_state_dict, strict=False)
|
|
|
50 |
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
51 |
checkpoint_path, iteration))
|
52 |
return model, optimizer, learning_rate, iteration
|
@@ -154,8 +156,9 @@ def get_hparams(init=True):
|
|
154 |
parser = argparse.ArgumentParser()
|
155 |
parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
|
156 |
help='JSON file for configuration')
|
157 |
-
parser.add_argument('-m', '--model', type=str,
|
158 |
help='Model name')
|
|
|
159 |
|
160 |
args = parser.parse_args()
|
161 |
model_dir = os.path.join("./logs", args.model)
|
@@ -177,6 +180,7 @@ def get_hparams(init=True):
|
|
177 |
|
178 |
hparams = HParams(**config)
|
179 |
hparams.model_dir = model_dir
|
|
|
180 |
return hparams
|
181 |
|
182 |
|
@@ -204,7 +208,7 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim
|
|
204 |
|
205 |
def get_hparams_from_dir(model_dir):
|
206 |
config_save_path = os.path.join(model_dir, "config.json")
|
207 |
-
with open(config_save_path, "r"
|
208 |
data = f.read()
|
209 |
config = json.loads(data)
|
210 |
|
@@ -214,7 +218,7 @@ def get_hparams_from_dir(model_dir):
|
|
214 |
|
215 |
|
216 |
def get_hparams_from_file(config_path):
|
217 |
-
with open(config_path, "r"
|
218 |
data = f.read()
|
219 |
config = json.loads(data)
|
220 |
|
|
|
11 |
|
12 |
MATPLOTLIB_FLAG = False
|
13 |
|
14 |
+
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
15 |
+
logger = logging
|
16 |
|
17 |
|
18 |
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
|
|
23 |
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
24 |
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
25 |
elif optimizer is None and not skip_optimizer:
|
26 |
+
#else: #Disable this line if Infer ,and enable the line upper
|
27 |
new_opt_dict = optimizer.state_dict()
|
28 |
new_opt_dict_params = new_opt_dict['param_groups'][0]['params']
|
29 |
new_opt_dict['param_groups'] = checkpoint_dict['optimizer']['param_groups']
|
|
|
42 |
new_state_dict[k] = saved_state_dict[k]
|
43 |
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
|
44 |
except:
|
45 |
+
print("error, %s is not in the checkpoint" % k)
|
46 |
new_state_dict[k] = v
|
47 |
if hasattr(model, 'module'):
|
48 |
model.module.load_state_dict(new_state_dict, strict=False)
|
49 |
else:
|
50 |
model.load_state_dict(new_state_dict, strict=False)
|
51 |
+
print("load ")
|
52 |
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
53 |
checkpoint_path, iteration))
|
54 |
return model, optimizer, learning_rate, iteration
|
|
|
156 |
parser = argparse.ArgumentParser()
|
157 |
parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
|
158 |
help='JSON file for configuration')
|
159 |
+
parser.add_argument('-m', '--model', type=str, default="./OUTPUT_MODEL",
|
160 |
help='Model name')
|
161 |
+
parser.add_argument('--cont', dest='cont', action="store_true", default=False, help="whether to continue training on the latest checkpoint")
|
162 |
|
163 |
args = parser.parse_args()
|
164 |
model_dir = os.path.join("./logs", args.model)
|
|
|
180 |
|
181 |
hparams = HParams(**config)
|
182 |
hparams.model_dir = model_dir
|
183 |
+
hparams.cont = args.cont
|
184 |
return hparams
|
185 |
|
186 |
|
|
|
208 |
|
209 |
def get_hparams_from_dir(model_dir):
|
210 |
config_save_path = os.path.join(model_dir, "config.json")
|
211 |
+
with open(config_save_path, "r") as f:
|
212 |
data = f.read()
|
213 |
config = json.loads(data)
|
214 |
|
|
|
218 |
|
219 |
|
220 |
def get_hparams_from_file(config_path):
|
221 |
+
with open(config_path, "r") as f:
|
222 |
data = f.read()
|
223 |
config = json.loads(data)
|
224 |
|