Spaces:
Runtime error
Runtime error
import json | |
import os | |
import torch | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
from model import get_model | |
# Prefer the GPU when one is available, but fall back to the CPU so that
# `.to(device)` and `torch.load(..., map_location=device)` below do not fail
# on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory containing the `<model_name>.pth` / `<model_name>.json` pairs
# read by load_unprejudiced().
models_path = 'models'
# GPT-2 tokenizer, created once at import time and shared module-wide.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def load_gpt2():
    """Return the pretrained ``'gpt2'`` language-model head model on ``device``."""
    gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
    # nn.Module.to() moves the parameters and returns the module itself.
    return gpt2.to(device)
def load_unprejudiced(model_name):
    """Build a model with ``get_model`` and restore its saved weights.

    Two files are read from ``models_path``:

    * ``<model_name>.json`` -- its ``'combination'`` entry supplies the
      options forwarded to ``get_model``.
    * ``<model_name>.pth`` -- a checkpoint whose ``'model_state_dict'`` entry
      is loaded into the freshly built model.

    Args:
        model_name: Base name (without extension) shared by the two files.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device``.

    Raises:
        FileNotFoundError: If the config or checkpoint file does not exist.
        KeyError: If ``'combination'``, one of the expected option keys, or
            ``'model_state_dict'`` is missing.
    """
    # Option names read from config['combination'] and passed through to
    # get_model() under the same keyword names.
    option_keys = (
        'in_net',
        'in_net_init_identity',
        'out_net',
        'out_net_init_identity',
        'freeze_ln',
        'freeze_pos',
        'freeze_wte',
        'freeze_ff',
        'freeze_attn',
        'dup_lm_head',
        'dup_lm_head_bias',
    )

    model_path = os.path.join(models_path, f'{model_name}.pth')
    model_json_path = os.path.join(models_path, f'{model_name}.json')

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(model_json_path, encoding='utf-8') as f:
        config = json.load(f)
    combination = config['combination']

    unprejudiced_model = get_model(
        device=device,
        gpt2_name='gpt2',
        **{key: combination[key] for key in option_keys},
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    unprejudiced_model.load_state_dict(checkpoint['model_state_dict'])
    return unprejudiced_model.to(device)