# esper / load.py
import os
import logging
import json
from pathlib import Path

import yaml
import torch

from policy import Policy

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)

def load_model_args(args):
    """Load hyperparameters saved next to the checkpoint and merge them into args."""
    checkpoint = Path(args.checkpoint + '.ckpt')
    assert checkpoint.is_file(), f"no checkpoint file: {checkpoint}"
    # Hyperparameters live alongside the checkpoint as either a JSON or a YAML file.
    args_path = Path(args.checkpoint + '.json')
    if args_path.is_file():
        with open(args_path) as f:
            hparams = json.load(f)
    else:
        args_path = Path(args.checkpoint + '.yaml')
        with open(args_path) as f:
            hparams = yaml.safe_load(f)
    # Copy the model-defining hyperparameters onto the runtime args.
    for key in ['init_model', 'clip_model_type', 'use_caption', 'use_style_reward', 'use_transformer_mapper',
                'prefix_length', 'clipcap_num_layers', 'use_ptuning_v2']:
        if key in hparams:
            setattr(args, key, hparams[key])
    args.loaded_init_model = True
    return args

def load_model(args, device, finetune=False):
    """Build the Policy model and load weights from the checkpoint onto the given device."""
    log.info('loading model')
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None', use_label_prefix=args.use_label_prefix)
    ckpt = args.checkpoint + '.ckpt'
    state = torch.load(ckpt, map_location=torch.device('cpu'))
    policy_key = 'policy_model'
    if policy_key in state:
        # Checkpoint stores the policy weights under a dedicated key.
        policy.model.load_state_dict(state[policy_key])
    else:
        # Lightning-style checkpoint: strip the module prefix from the state-dict keys.
        weights = state['state_dict']
        key = 'policy.model.'
        if not any(k for k in weights.keys() if k.startswith(key)):
            key = 'model.model.'
        weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
        # weights = {k: v for k, v in weights.items() if k.startswith('clip_project.')}
        policy.model.load_state_dict(weights, strict=False)
    model = policy
    model = model.to(device)
    return model
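

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical driver showing how load_model_args and load_model
# might be wired together. The checkpoint/label paths and the SimpleNamespace
# fields below are placeholder assumptions, not values shipped with this repo;
# the remaining model hyperparameters are filled in from the checkpoint's
# .json/.yaml sidecar by load_model_args.
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        checkpoint='checkpoints/esper_demo',  # expects .ckpt plus .json or .yaml alongside (placeholder path)
        label_path='data/labels.json',        # placeholder path
        use_label_prefix=False,
    )
    demo_args = load_model_args(demo_args)    # merges init_model, prefix_length, ... from the hparams file
    demo_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    demo_model = load_model(demo_args, demo_device)
    log.info('loaded model: %s', type(demo_model).__name__)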