File size: 6,120 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import six
import argparse
import torch
from onmt.utils.logging import init_logger, logger


def get_vocabs(dict_path):
    """Load the source and target vocabularies from a saved fields file.

    Args:
        dict_path: path passed to ``torch.load``; the loaded object is
            indexed with the keys ``'src'`` and ``'tgt'``.

    Returns:
        A ``(enc_vocab, dec_vocab)`` pair: the vocab found on the ``'src'``
        entry and the vocab found on the ``'tgt'`` entry.
    """
    # NOTE: torch.load unpickles the file, so only point this at files
    # you trust.
    loaded_fields = torch.load(dict_path)

    def _side_vocab(side):
        # Prefer the vocab hanging off ``base_field``; when that attribute
        # chain is missing, fall back to the entry's own ``vocab``.
        entry = loaded_fields[side]
        try:
            return entry.base_field.vocab
        except AttributeError:
            return entry.vocab

    enc_vocab = _side_vocab('src')
    dec_vocab = _side_vocab('tgt')

    logger.info("From: %s" % dict_path)
    for template, side_vocab in (
            ("\t* source vocab: %d words", enc_vocab),
            ("\t* target vocab: %d words", dec_vocab)):
        logger.info(template % len(side_vocab))

    return enc_vocab, dec_vocab


def read_embeddings(file_enc, skip_lines=0, filter_set=None):
    """Parse a text embedding file into a ``{word: [float, ...]}`` dict.

    Every data line is decoded as UTF-8, stripped, and split on single
    spaces: the first field is the word, the remaining fields are its
    vector components.

    Args:
        file_enc: path of the embedding file (opened in binary mode).
        skip_lines: number of leading lines to ignore.
        filter_set: optional container supporting ``in``; when given, only
            words contained in it are stored in the result.

    Returns:
        A pair ``(embs, total_vectors_in_file)``. ``embs`` maps each kept
        word to its list of floats; ``total_vectors_in_file`` counts every
        vector line that was read, including the ones dropped by
        ``filter_set``.

    Raises:
        ValueError: if a vector component cannot be converted to ``float``.
    """
    embs = dict()
    total_vectors_in_file = 0
    with open(file_enc, 'rb') as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue

            text = line.decode('utf8').strip()
            if not text:
                # Blank (or whitespace-only) line: it carries no vector, so
                # it must neither be counted nor stored under the '' key.
                continue

            # Split on a single space only, so a token containing other
            # whitespace characters is kept intact.
            l_split = text.split(' ')
            if len(l_split) == 2:
                # Two-field line; presumably a word2vec-style
                # "<count> <dim>" header rather than a vector.
                continue
            total_vectors_in_file += 1
            if filter_set is not None and l_split[0] not in filter_set:
                continue
            embs[l_split[0]] = [float(em) for em in l_split[1:]]
    return embs, total_vectors_in_file


def convert_to_torch_tensor(word_to_float_list_dict, vocab):
    """Build an embedding matrix with one row per vocabulary entry.

    Args:
        word_to_float_list_dict: mapping ``word -> list of floats``; the
            length of the first value fixes the embedding dimension.
        vocab: object supporting ``len()`` and exposing a ``stoi`` mapping
            from word to row index.

    Returns:
        A ``torch`` tensor of shape ``(len(vocab), dim)``. Rows of words
        that have no loaded vector stay zero.

    Raises:
        ValueError: if ``word_to_float_list_dict`` is empty, since the
            embedding dimension cannot be determined.
    """
    if not word_to_float_list_dict:
        raise ValueError("No embedding vectors were loaded; cannot "
                         "determine the embedding dimension.")

    dim = len(next(iter(word_to_float_list_dict.values())))
    tensor = torch.zeros((len(vocab), dim))
    stoi = vocab.stoi
    for word, values in word_to_float_list_dict.items():
        # The loaded dict can hold words that are absent from this
        # particular vocab (e.g. when one file is shared by both sides).
        # Indexing stoi for them would either raise KeyError or, for a
        # defaulting mapping, write into the default row, so skip them.
        if word not in stoi:
            continue
        tensor[stoi[word]] = torch.Tensor(values)
    return tensor


def calc_vocab_load_stats(vocab, loaded_embed_dict):
    """Report how much of a vocabulary is covered by loaded embeddings.

    Args:
        vocab: object supporting ``len()`` and exposing a ``stoi`` mapping
            whose keys are the vocabulary words.
        loaded_embed_dict: mapping whose keys are the words that have a
            loaded vector.

    Returns:
        A tuple ``(matching_count, missing_count, percent_matching)``:
        the number of ``stoi`` keys that also appear in
        ``loaded_embed_dict``, ``len(vocab)`` minus that number, and the
        matching count as a percentage of ``len(vocab)``.
    """
    vocab_size = len(vocab)
    covered_words = set(loaded_embed_dict.keys()).intersection(
        vocab.stoi.keys())
    n_found = len(covered_words)
    n_absent = vocab_size - n_found
    coverage = n_found / vocab_size * 100
    return n_found, n_absent, coverage


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
    parser.add_argument('-emb_file_both', required=False,
                        help="loads Embeddings for both source and target "
                             "from this file.")
    parser.add_argument('-emb_file_enc', required=False,
                        help="source Embeddings from this file")
    parser.add_argument('-emb_file_dec', required=False,
                        help="target Embeddings from this file")
    parser.add_argument('-output_file', required=True,
                        help="Output file for the prepared data")
    parser.add_argument('-dict_file', required=True,
                        help="Dictionary file")
    # NOTE(review): -verbose is accepted but not read by main().
    parser.add_argument('-verbose', action="store_true", default=False)
    parser.add_argument('-skip_lines', type=int, default=0,
                        help="Skip first lines of the embedding file")
    parser.add_argument('-type', choices=["GloVe", "word2vec"],
                        default="GloVe")
    return parser


def main():
    """Convert text embedding files into per-side torch tensors.

    Reads the vocabularies from ``-dict_file``, loads the embeddings from
    either ``-emb_file_both`` or the ``-emb_file_enc`` / ``-emb_file_dec``
    pair, logs how much of each vocabulary is covered, and saves one tensor
    per side to ``<output_file>.enc.pt`` and ``<output_file>.dec.pt``.

    Raises:
        ValueError: if ``-emb_file_both`` is combined with a per-side file,
            or if neither ``-emb_file_both`` nor both per-side files are
            given.
    """
    opt = _build_arg_parser().parse_args()

    enc_vocab, dec_vocab = get_vocabs(opt.dict_file)

    # Read in embeddings.
    # For word2vec exactly one leading line is skipped and -skip_lines is
    # not consulted.
    skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines
    if opt.emb_file_both is not None:
        if opt.emb_file_enc is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_enc.")
        if opt.emb_file_dec is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_dec.")
        # Keep every word that either side needs; the same dict is then
        # shared by the source and target sides.
        set_of_src_and_tgt_vocab = \
            set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys())
        logger.info("Reading encoder and decoder embeddings from {}".format(
            opt.emb_file_both))
        src_vectors, total_vec_count = \
            read_embeddings(opt.emb_file_both, skip_lines,
                            set_of_src_and_tgt_vocab)
        tgt_vectors = src_vectors
        logger.info("\tFound {} total vectors in file".format(total_vec_count))
    else:
        if opt.emb_file_enc is None:
            raise ValueError("-emb_file_enc not provided. Please specify "
                             "the file with encoder embeddings, or pass in "
                             "-emb_file_both")
        if opt.emb_file_dec is None:
            raise ValueError("-emb_file_dec not provided. Please specify "
                             "the file with decoder embeddings, or pass in "
                             "-emb_file_both")
        logger.info("Reading encoder embeddings from {}".format(
            opt.emb_file_enc))
        src_vectors, total_vec_count = read_embeddings(
            opt.emb_file_enc, skip_lines,
            filter_set=enc_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file.".format(
            total_vec_count))
        logger.info("Reading decoder embeddings from {}".format(
            opt.emb_file_dec))
        tgt_vectors, total_vec_count = read_embeddings(
            opt.emb_file_dec, skip_lines,
            filter_set=dec_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file".format(total_vec_count))
    logger.info("After filtering to vectors in vocab:")
    logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(enc_vocab, src_vectors))
    logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(dec_vocab, tgt_vectors))

    # Write to file
    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                % (enc_output_file, dec_output_file))
    torch.save(
        convert_to_torch_tensor(src_vectors, enc_vocab),
        enc_output_file
    )
    torch.save(
        convert_to_torch_tensor(tgt_vectors, dec_vocab),
        dec_output_file
    )
    logger.info("\nDone.")


# Script entry point: initialise the logger first so that everything main()
# logs goes through it (the argument is presumably the log file path; see
# onmt.utils.logging.init_logger).
if __name__ == "__main__":
    init_logger('embeddings_to_torch.log')
    main()