import json
import sys

import torch
import yaml
from transformers import AutoTokenizer

def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
    if mode in ['random', 'random1', 'random_up_proj', 'glove']:
        if modality == 'synth':
            # Restored from the commented-out original: pick the synthetic-data
            # config, rebuild the vocab, and load the saved random embeddings.
            print(file, 'deciding what to load::: ')
            if 'synth128' in file:
                config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
            else:
                config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'
            sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
            from dataset import Dataset as SynthDataset
            with open(config) as f:
                args_synth = yaml.safe_load(f)
            dataset = SynthDataset(args_synth)
            model = torch.nn.Embedding(len(dataset.vocab), emb_dim)
            path_save = '{}/random_emb.torch'.format(file)
            model.load_state_dict(torch.load(path_save))
            # invert the vocab so ids map back to token strings
            tokenizer = {v: k for k, v in dataset.vocab.items()}
        else:
            if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
                tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            else:
                path_save_tokenizer = '{}/vocab.json'.format(file)
                print(f'loading from {path_save_tokenizer}')
                with open(path_save_tokenizer, 'r') as f:
                    vocab = json.load(f)
                print(len(vocab))
                tokenizer = {v: k for k, v in vocab.items()}
            model = torch.nn.Embedding(len(tokenizer), emb_dim)
            path_save = '{}/random_emb.torch'.format(file)
            model.load_state_dict(torch.load(path_save))
    else:
        raise ValueError(f'unsupported mode: {mode}')

    return model, tokenizer
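
# Example usage (a sketch; the checkpoint directory below is a hypothetical
# placeholder that must contain the vocab.json and random_emb.torch files
# written out during training):
#
#   model, tokenizer = load_models(
#       modality='e2e-tgt', mode='random', model_name_or_path=None,
#       emb_dim=16, file='diffusion_models/my_run_dir')
#   print(len(tokenizer), model.weight.shape)  # vocab size, (vocab, emb_dim)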


def load_tokenizer(modality, mode, model_name_or_path):
    if mode in ['random', 'random_up_proj', 'glove']:
        if modality == 'synth':
            print(model_name_or_path, 'deciding what to load::: ')
            if 'synth128' in model_name_or_path:
                config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
            else:
                config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'

            sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
            from dataset import Dataset as SynthDataset
            with open(config) as f:
                args_synth = yaml.safe_load(f)
            dataset = SynthDataset(args_synth)
            tokenizer = {v: k for k, v in dataset.vocab.items()}
        elif modality == 'book':
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        else:
            path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path)
            with open(path_save_tokenizer, 'r') as f:
                vocab = json.load(f)
            # invert the vocab so ids map back to token strings
            tokenizer = {v: k for k, v in vocab.items()}
    else:
        raise ValueError(f'unsupported mode: {mode}')

    return tokenizer
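
# Example usage (a sketch; 'path/to/run_dir' is a hypothetical directory that
# must contain the vocab.json saved at training time):
#
#   tokenizer = load_tokenizer(modality='e2e-tgt', mode='random',
#                              model_name_or_path='path/to/run_dir')
#   print(tokenizer[5])  # maps a token id back to its string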

def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
    decoded_out_lst = []
    if mode in ['random', 'random_up_proj', 'glove']:
        down_proj_emb = model.weight  # (vocab_size, emb_dim) input embeddings

        def get_knn(down_proj_emb, text_emb, dist='cos'):
            # adjacency: (vocab_size, num_tokens); larger means closer
            if dist == 'cos':
                # note: an unnormalized dot product, not a true cosine similarity
                adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device)
            elif dist == 'l2':
                adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) \
                            - text_emb.unsqueeze(0).expand(down_proj_emb.size(0), -1, -1)
                adjacency = -torch.norm(adjacency, dim=-1)
            # top-6 vocab entries per position; row 0 is the nearest neighbor
            topk_out = torch.topk(adjacency, k=6, dim=0)
            return topk_out.values, topk_out.indices

        dist = 'l2'
        for text_emb in text_emb_lst:
            text_emb = torch.tensor(text_emb)
            if len(text_emb.shape) > 2:
                # flatten (batch, seq_len, emb_dim) to (batch * seq_len, emb_dim)
                text_emb = text_emb.view(-1, text_emb.size(-1))
            val, indices = get_knn(down_proj_emb, text_emb.to(down_proj_emb.device), dist=dist)
            # indices[0] holds the nearest vocab id for every position
            decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()])
            decoded_out_lst.append(decoded_out)

    return decoded_out_lst
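
# Example usage (a sketch under assumed shapes: each entry of text_emb_lst is a
# (seq_len, emb_dim) tensor or array of denoised embeddings, and model and
# tokenizer come from load_models above; the run directory is hypothetical):
#
#   model, tokenizer = load_models('e2e-tgt', 'random', None, 16, 'path/to/run_dir')
#   ids = torch.randint(0, model.num_embeddings, (10,))
#   noisy = model.weight[ids] + 0.01 * torch.randn(10, 16)  # perturbed embeddings
#   texts = rounding_func('random', [noisy.detach()], model, tokenizer)
#   print(texts[0])  # tokens recovered by nearest-neighbor (l2) rounding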