# debiasing-lms / load_model.py
# Author: Michael Gira
# Commit 74b3160 — "Change device according to environment"
import json
import os
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from model import get_model
# Run on GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory holding saved checkpoints ({name}.pth) and their configs ({name}.json).
models_path = 'models'
# Shared GPT-2 tokenizer, created once at import time.
# NOTE(review): from_pretrained may download weights on first use — a network
# side effect at import; tokenizer is not referenced elsewhere in this file.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def load_gpt2():
    """Return the pretrained GPT-2 language model, moved to the active device."""
    return GPT2LMHeadModel.from_pretrained('gpt2').to(device)
def load_unprejudiced(model_name):
    """Load a fine-tuned "unprejudiced" GPT-2 variant from the models directory.

    Expects two files under ``models_path``:
      - ``{model_name}.json`` — config whose ``'combination'`` dict supplies
        the architectural flags passed through to ``get_model``.
      - ``{model_name}.pth`` — checkpoint with a ``'model_state_dict'`` entry.

    Args:
        model_name: Base filename (without extension) of the saved model.

    Returns:
        The reconstructed model with checkpoint weights loaded, on ``device``.

    Raises:
        FileNotFoundError: if either file is missing.
        KeyError: if the config lacks 'combination' or the checkpoint lacks
            'model_state_dict'.
    """
    model_path = os.path.join(models_path, f'{model_name}.pth')
    model_json_path = os.path.join(models_path, f'{model_name}.json')

    # json.load reads straight from the file handle (idiomatic vs. loads(f.read())).
    with open(model_json_path) as f:
        config = json.load(f)
    combination = config['combination']

    # Rebuild the architecture exactly as it was configured at training time.
    unprejudiced_model = get_model(
        device=device,
        gpt2_name='gpt2',
        in_net=combination['in_net'],
        in_net_init_identity=combination['in_net_init_identity'],
        out_net=combination['out_net'],
        out_net_init_identity=combination['out_net_init_identity'],
        freeze_ln=combination['freeze_ln'],
        freeze_pos=combination['freeze_pos'],
        freeze_wte=combination['freeze_wte'],
        freeze_ff=combination['freeze_ff'],
        freeze_attn=combination['freeze_attn'],
        dup_lm_head=combination['dup_lm_head'],
        dup_lm_head_bias=combination['dup_lm_head_bias']
    )

    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources. map_location avoids GPU->CPU mismatch.
    checkpoint = torch.load(model_path, map_location=device)
    unprejudiced_model.load_state_dict(checkpoint['model_state_dict'])
    return unprejudiced_model.to(device)