import os
import logging
import json
from pathlib import Path

import yaml
import torch

from policy import Policy

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def load_model_args(args):
    """Restore saved hyperparameters from the checkpoint's sidecar file
    (<checkpoint>.json, falling back to <checkpoint>.yaml) onto args."""
    checkpoint = Path(args.checkpoint + '.ckpt')
    assert checkpoint.is_file(), f"no checkpoint file: {args.checkpoint}"
    args_path = Path(args.checkpoint + '.json')
    if args_path.is_file():
        with open(args_path) as f:
            hparams = json.load(f)
    else:
        args_path = Path(args.checkpoint + '.yaml')
        with open(args_path) as f:
            hparams = yaml.safe_load(f)
    # Only overwrite args with keys the training run actually recorded.
    for key in ['init_model', 'clip_model_type', 'use_caption', 'use_style_reward',
                'use_transformer_mapper', 'prefix_length', 'clipcap_num_layers',
                'use_ptuning_v2']:
        if key in hparams:
            setattr(args, key, hparams[key])
    args.loaded_init_model = True
    return args
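
# For reference, the sidecar read above might look like the following
# hypothetical <checkpoint>.json (keys match the list in load_model_args;
# the values shown are illustrative only, not from any real run):
#
# {
#     "init_model": "gpt2",
#     "clip_model_type": "ViT-B/32",
#     "prefix_length": 10,
#     "clipcap_num_layers": 8,
#     "use_transformer_mapper": false,
#     "use_ptuning_v2": false
# }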
def load_model(args, device, finetune=False):
    """Build a Policy from the restored hyperparameters and load its
    weights from <checkpoint>.ckpt."""
    log.info('loading model')
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None', use_label_prefix=args.use_label_prefix)
    ckpt = args.checkpoint + '.ckpt'
    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
    state = torch.load(ckpt, map_location='cpu')
    policy_key = 'policy_model'
    if policy_key in state:
        # Weights saved directly under 'policy_model'.
        policy.model.load_state_dict(state[policy_key])
    else:
        # Nested 'state_dict' checkpoint (e.g. PyTorch Lightning): strip the
        # wrapper prefix from each key before loading.
        weights = state['state_dict']
        key = 'policy.model.'
        if not any(k.startswith(key) for k in weights):
            key = 'model.model.'
        weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
        # weights = {k: v for k, v in weights.items() if k.startswith('clip_project.')}
        policy.model.load_state_dict(weights, strict=False)
    return policy.to(device)
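
# Example usage (a minimal sketch: the CLI flags below are assumptions for
# illustration; the model-architecture attributes are expected to be filled
# in from the checkpoint's .json/.yaml sidecar by load_model_args):
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True,
                        help='checkpoint path without the .ckpt suffix')
    parser.add_argument('--label_path', default='None')
    parser.add_argument('--use_label_prefix', action='store_true')
    cli_args = parser.parse_args()

    cli_args = load_model_args(cli_args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model(cli_args, device)
    log.info('model ready on %s', device)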