# esper / load.py
# Source page header: jiwan-chung, commit 0bf81ba ("demo init"), 2.15 kB.
import os
import logging
import json
from pathlib import Path
import yaml
import torch
from policy import Policy
def _resolve_log_level(name):
    """Normalize a ``LOGLEVEL`` string into a value ``logging`` accepts.

    ``logging`` only recognizes upper-case level names (``"DEBUG"``) or
    integers. Without this normalization, ``LOGLEVEL=debug`` or
    ``LOGLEVEL=10`` would make importing this module fail with
    ``ValueError: Unknown level``.

    Args:
        name: raw level string, e.g. from the environment.

    Returns:
        ``int(name)`` for an all-digit string, otherwise the stripped,
        upper-cased name. Unknown names are still rejected by ``logging``.
    """
    name = name.strip()
    return int(name) if name.isdigit() else name.upper()


# Root logging is configured at import time; set LOGLEVEL to change verbosity.
logging.basicConfig(level=_resolve_log_level(os.environ.get("LOGLEVEL", "INFO")))
log = logging.getLogger(__name__)
def load_model_args(args):
    """Overlay model hyperparameters saved next to a checkpoint onto ``args``.

    ``args.checkpoint`` is a path stem. The checkpoint must exist at
    ``<stem>.ckpt``. Hyperparameters are read from ``<stem>.json`` if that
    file exists, otherwise from ``<stem>.yaml``. Only the model-related keys
    listed below are copied; keys absent from the file leave ``args`` as is.

    Args:
        args: mutable namespace with a string ``checkpoint`` attribute.

    Returns:
        The same ``args`` object, updated in place, with
        ``loaded_init_model`` set to ``True``.

    Raises:
        FileNotFoundError: if ``<stem>.ckpt`` is missing, or if neither
            ``<stem>.json`` nor ``<stem>.yaml`` exists.
    """
    # The stem is extended by string concatenation rather than
    # Path.with_suffix, so stems that already contain dots are preserved.
    checkpoint = Path(args.checkpoint + '.ckpt')
    # Explicit check instead of `assert`, which `python -O` strips.
    if not checkpoint.is_file():
        raise FileNotFoundError(f"no checkpoint file: {checkpoint}")

    json_path = Path(args.checkpoint + '.json')
    if json_path.is_file():
        with open(json_path, encoding='utf-8') as f:
            hparams = json.load(f)
    else:
        with open(Path(args.checkpoint + '.yaml'), encoding='utf-8') as f:
            hparams = yaml.safe_load(f)
    # An empty YAML file (or a JSON `null`) parses to None; treat that as
    # "no overrides" instead of failing on the membership test below.
    hparams = hparams or {}

    model_keys = (
        'init_model', 'clip_model_type', 'use_caption', 'use_style_reward',
        'use_transformer_mapper', 'prefix_length', 'clipcap_num_layers',
        'use_ptuning_v2',
    )
    for key in model_keys:
        if key in hparams:
            setattr(args, key, hparams[key])
    # NOTE(review): set even when the file has no 'init_model' entry -- confirm intended.
    args.loaded_init_model = True
    return args
def load_model(args, device, finetune=False):
    """Build a ``Policy`` from ``args`` and load its weights from a checkpoint.

    Args:
        args: namespace providing ``init_model``, ``label_path``,
            ``prefix_length``, ``clipcap_num_layers``,
            ``use_transformer_mapper``, ``use_label_prefix`` and
            ``checkpoint`` (a path stem; ``.ckpt`` is appended).
        device: passed to ``Policy`` and used for the final ``.to(device)``.
        finetune: not used by this function; kept for interface compatibility.

    Returns:
        The ``Policy`` with checkpoint weights loaded, moved to ``device``.
    """
    log.info('loading model')
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None', use_label_prefix=args.use_label_prefix)

    ckpt = args.checkpoint + '.ckpt'
    # Deserialize onto the CPU so a checkpoint saved from a GPU run also loads
    # on a CPU-only host. Assuming policy.model is a torch.nn.Module,
    # load_state_dict copies values into the model's existing parameters, and
    # the final .to(device) places everything on the requested device.
    state = torch.load(ckpt, map_location='cpu')

    policy_key = 'policy_model'
    if policy_key in state:
        # The checkpoint stores the policy's state dict directly: load strictly.
        policy.model.load_state_dict(state[policy_key])
    else:
        # Otherwise pick the policy's entries out of a flat 'state_dict'. Their
        # names carry a 'policy.model.' prefix (or, failing that,
        # 'model.model.'), which is stripped before loading.
        weights = state['state_dict']
        prefix = 'policy.model.'
        if not any(name.startswith(prefix) for name in weights):
            prefix = 'model.model.'
        weights = {
            name[len(prefix):]: tensor
            for name, tensor in weights.items()
            if name.startswith(prefix)
        }
        if not weights:
            # strict=False would accept an empty dict silently and leave the
            # policy with the weights Policy(...) constructed it with.
            log.warning("no 'policy.model.' or 'model.model.' parameters found in %s; "
                        "policy weights were NOT loaded from the checkpoint", ckpt)
        # Non-strict: parameters absent from the checkpoint keep their values.
        policy.model.load_state_dict(weights, strict=False)
    return policy.to(device)