Spaces:
Running
Running
import os | |
import sys | |
import logging | |
import torch | |
from contants import config | |
# Module-level flag, initialised to False; nothing in the visible part of
# this file reads or updates it.
MATPLOTLIB_FLAG = False

# Configure the root logger to emit DEBUG-and-above records on stdout.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

# `logger` is an alias for the `logging` module itself, so calls such as
# `logger.info(...)` go through the module-level helpers (root logger).
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
    """Restore model weights (and optionally optimizer state) from a checkpoint.

    The checkpoint is a dict read with ``torch.load`` that must contain the
    keys ``'iteration'``, ``'learning_rate'`` and ``'model'``; ``'optimizer'``
    is only read when an optimizer is supplied and ``skip_optimizer`` is false.

    Parameters whose key is missing from the checkpoint, or whose shape differs
    from the model's, keep the model's current value (zeros for keys containing
    ``ja_bert_proj`` / ``en_bert_proj``) and the problem is logged.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Module to load into. If it has a ``module`` attribute, the
            wrapped ``model.module`` is used instead.
        optimizer: Optional optimizer whose state is restored from the
            checkpoint.
        skip_optimizer: When true, the optimizer state is never restored.
        version: Model version, if known. Only used to decide whether the
            "add a version to config.json" hints are logged for missing
            ``ja_bert_proj`` / ``en_bert_proj`` parameters.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If a required key is missing from the checkpoint dict.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location=config.system.device)
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']

    if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    # The original code had an extra branch here, guarded by
    # `optimizer is None and not skip_optimizer`, that started with
    # `optimizer.state_dict()`. It could only run with `optimizer` being None,
    # so it always raised AttributeError (e.g. for a call using the default
    # arguments). Without an optimizer there is nothing to restore, so the
    # branch is dropped. (Its comment suggested swapping it for a plain
    # `else:` outside of inference; that variant was not active code.)

    saved_state_dict = checkpoint_dict['model']
    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        in_checkpoint = k in saved_state_dict
        if in_checkpoint and saved_state_dict[k].shape == v.shape:
            new_state_dict[k] = saved_state_dict[k]
            continue

        # Fallback path: the key is missing or its shape does not match.
        if in_checkpoint:
            problem = (f"{k} has shape {tuple(saved_state_dict[k].shape)} in the checkpoint "
                       f"but the model expects {tuple(v.shape)}")
        else:
            problem = f"{k} is not in the checkpoint"

        # Handle legacy model versions and provide appropriate warnings.
        if "ja_bert_proj" in k:
            v = torch.zeros_like(v)
            if version is None:
                logger.error(problem)
                logger.warning(
                    "If you're using an older version of the model, consider adding "
                    "the \"version\" parameter to the model's config.json. "
                    "For instance: \"version\": \"1.0.1\"")
        elif "flow.flows.0.enc.attn_layers.3" in k:
            logger.error(problem)
            logger.warning(
                "If you're using a transitional version, please add the "
                "\"version\": \"1.1.0-transition\" parameter to the model's config.json. "
                "For instance: \"version\": \"1.1.0-transition\"")
        elif "en_bert_proj" in k:
            v = torch.zeros_like(v)
            if version is None:
                logger.error(problem)
                logger.warning(
                    "If you're using an older version of the model, consider adding "
                    "the \"version\" parameter to the model's config.json. "
                    "For instance: \"version\": \"1.1.1\"")
        else:
            logger.error(problem)

        # Keep the model's own tensor (or the zero tensor chosen above).
        new_state_dict[k] = v

    # strict=False: keys present in only one of the two dicts are tolerated.
    target.load_state_dict(new_state_dict, strict=False)

    logger.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
    return model, optimizer, learning_rate, iteration
def process_legacy_versions(hps):
    """Return the version string from ``hps`` without a leading ``v``/``V``.

    ``hps.version`` is preferred; ``hps.data.version`` is the fallback and
    ``None`` is used when neither exists. ``hps.data`` is always accessed, so
    ``hps`` must have a ``data`` attribute. A falsy version (``None`` or an
    empty string) is returned unchanged.
    """
    fallback = getattr(hps.data, "version", None)
    resolved = getattr(hps, "version", fallback)

    # Nothing to normalise when no (non-empty) version is configured.
    if not resolved:
        return resolved

    # Accept both "v1.0.1" and "V1.0.1" spellings.
    if resolved[0].lower() == "v":
        return resolved[1:]
    return resolved