import json
import os

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from model import get_model

# Use the GPU when available; every model/checkpoint is mapped onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory holding saved checkpoints (<name>.pth) and their configs (<name>.json).
models_path = 'models'

# Shared GPT-2 tokenizer, exported for callers of the loaders below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


def load_gpt2():
    """Return the pretrained baseline GPT-2 LM-head model, moved to `device`."""
    return GPT2LMHeadModel.from_pretrained('gpt2').to(device)


def load_unprejudiced(model_name):
    """Reconstruct and load a fine-tuned "unprejudiced" model from `models_path`.

    Expects two files under `models_path`:
      - ``<model_name>.json``: config whose ``'combination'`` dict supplies the
        architecture flags passed through to :func:`get_model`.
      - ``<model_name>.pth``: checkpoint containing a ``'model_state_dict'`` entry.

    Args:
        model_name: Base name (without extension) of the saved model files.

    Returns:
        The rebuilt model with checkpoint weights loaded, placed on `device`.

    Raises:
        FileNotFoundError: If either the config or checkpoint file is missing.
        KeyError: If the config lacks ``'combination'`` or the checkpoint lacks
            ``'model_state_dict'`` (or a required combination flag is absent).
    """
    model_path = os.path.join(models_path, f'{model_name}.pth')
    model_json_path = os.path.join(models_path, f'{model_name}.json')

    with open(model_json_path) as f:
        # json.load streams the file directly (was json.loads(f.read())).
        config = json.load(f)
    combination = config['combination']

    unprejudiced_model = get_model(
        device=device,
        gpt2_name='gpt2',
        in_net=combination['in_net'],
        in_net_init_identity=combination['in_net_init_identity'],
        out_net=combination['out_net'],
        out_net_init_identity=combination['out_net_init_identity'],
        freeze_ln=combination['freeze_ln'],
        freeze_pos=combination['freeze_pos'],
        freeze_wte=combination['freeze_wte'],
        freeze_ff=combination['freeze_ff'],
        freeze_attn=combination['freeze_attn'],
        dup_lm_head=combination['dup_lm_head'],
        dup_lm_head_bias=combination['dup_lm_head_bias'],
    )

    # map_location lets CPU-only machines load checkpoints that were saved on GPU.
    checkpoint = torch.load(model_path, map_location=device)
    unprejudiced_model.load_state_dict(checkpoint['model_state_dict'])
    return unprejudiced_model.to(device)