File size: 1,613 Bytes
6f82d3b
 
 
 
 
 
74b3160
6f82d3b
 
 
 
659007c
6f82d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import json
import os
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from model import get_model

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory holding checkpoint/config pairs: <name>.pth (weights) and <name>.json (architecture flags).
models_path = 'models'

# Shared GPT-2 tokenizer; loaded once at import time for use by callers of the loaders below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


def load_gpt2():
    """Return the pretrained GPT-2 language model, moved to the active device."""
    return GPT2LMHeadModel.from_pretrained('gpt2').to(device)


def load_unprejudiced(model_name):
    """Load a fine-tuned "unprejudiced" GPT-2 variant from disk.

    Reads the architecture flags from ``models/<model_name>.json`` and the
    trained weights from ``models/<model_name>.pth``, rebuilds the model via
    ``get_model``, and restores the saved state dict.

    Args:
        model_name: Basename (no extension) of the checkpoint/config pair
            inside ``models_path``.

    Returns:
        The reconstructed model with weights loaded, on ``device``.
    """
    model_path = os.path.join(models_path, f'{model_name}.pth')
    model_json_path = os.path.join(models_path, f'{model_name}.json')

    # Only the JSON read needs the file handle; close it before the
    # (potentially slow) model construction and weight loading below.
    with open(model_json_path, encoding='utf-8') as f:
        combination = json.load(f)['combination']

    unprejudiced_model = get_model(
        device=device,
        gpt2_name='gpt2',
        in_net=combination['in_net'],
        in_net_init_identity=combination['in_net_init_identity'],
        out_net=combination['out_net'],
        out_net_init_identity=combination['out_net_init_identity'],
        freeze_ln=combination['freeze_ln'],
        freeze_pos=combination['freeze_pos'],
        freeze_wte=combination['freeze_wte'],
        freeze_ff=combination['freeze_ff'],
        freeze_attn=combination['freeze_attn'],
        dup_lm_head=combination['dup_lm_head'],
        dup_lm_head_bias=combination['dup_lm_head_bias']
    )
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    unprejudiced_model.load_state_dict(checkpoint['model_state_dict'])
    return unprejudiced_model.to(device)