File size: 11,732 Bytes
4fb0bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
from bidict import bidict
import pickle
import logging

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class Vocabulary:
    """Map tokens to integer indices, with an independent mapping per namespace.

    Every namespace is stored in ``self.vocab`` as a ``bidict`` so it can be
    queried in both directions (token -> index and index -> token). When a
    namespace is created, the default padding / unknown tokens are inserted
    first unless the namespace is listed in ``no_pad_namespace`` /
    ``no_unk_namespace`` or has an entry in ``contain_pad_namespace`` /
    ``contain_unk_namespace``.
    """

    DEFAULT_PAD_TOKEN = '*@PAD@*'
    DEFAULT_UNK_TOKEN = '*@UNK@*'

    def __init__(self,
                 counters=None,
                 min_count=None,
                 pretrained_vocab=None,
                 intersection_namespace=None,
                 no_pad_namespace=None,
                 no_unk_namespace=None,
                 contain_pad_namespace=None,
                 contain_unk_namespace=None):
        """Initialize the vocabulary from token counters and pretrained vocabularies.

        Every argument is optional; ``None`` is treated as "empty".

        Keyword Arguments:
            counters {dict} -- namespace -> {token: count} (default: {None})
            min_count {dict} -- namespace -> minimum count a token needs to be kept (default: {None})
            pretrained_vocab {dict} -- namespace -> {token: index} (default: {None})
            intersection_namespace {dict} -- pretrained namespace -> existing namespace; only
                pretrained tokens that also occur in the existing namespace are kept, which
                keeps very large pretrained vocabularies small (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> token of that namespace that acts as
                its padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> token of that namespace that acts as
                its unknown token (default: {None})
        """

        # Copy every argument so later updates never touch caller-owned objects.
        self.min_count = dict(min_count or ())
        self.intersection_namespace = dict(intersection_namespace or ())
        self.no_pad_namespace = set(no_pad_namespace or ())
        self.no_unk_namespace = set(no_unk_namespace or ())
        self.contain_pad_namespace = dict(contain_pad_namespace or ())
        self.contain_unk_namespace = dict(contain_unk_namespace or ())
        self.vocab = dict()

        # Counters go first: the pretrained step may filter against the
        # namespaces built here (see intersection_namespace).
        self.extend_from_counter(counters, self.min_count, self.no_pad_namespace,
                                 self.no_unk_namespace)

        self.extend_from_pretrained_vocab(pretrained_vocab, self.intersection_namespace,
                                          self.no_pad_namespace, self.no_unk_namespace)

        logger.info("Initialize vocabulary successfully.")

    def extend_from_pretrained_vocab(self,
                                     pretrained_vocab,
                                     intersection_namespace=None,
                                     no_pad_namespace=None,
                                     no_unk_namespace=None,
                                     contain_pad_namespace=None,
                                     contain_unk_namespace=None):
        """Build namespaces from pretrained vocabularies.

        The settings passed here are merged into the ones already stored on the
        instance before any namespace is built.

        Arguments:
            pretrained_vocab {dict} -- namespace -> {token: index}; ``None`` means nothing to add

        Keyword Arguments:
            intersection_namespace {dict} -- pretrained namespace -> existing namespace used to
                filter the pretrained tokens (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})

        Raises:
            KeyError: the namespace referenced by ``intersection_namespace`` does not exist
        """

        self.intersection_namespace.update(intersection_namespace or {})
        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(contain_pad_namespace or {})
        self.contain_unk_namespace.update(contain_unk_namespace or {})

        for namespace, vocab in (pretrained_vocab or {}).items():
            # NOTE(review): this recreates the namespace, so tokens already
            # stored under the same name are discarded -- confirm intended.
            self.__namespace_init(namespace)
            target = self.vocab[namespace]

            if namespace in self.intersection_namespace:
                # Keep only pretrained tokens known to the reference namespace.
                allowed = self.vocab[self.intersection_namespace[namespace]]
                entries = ((key, value) for key, value in vocab.items() if key in allowed)
            else:
                entries = vocab.items()

            # Pretrained indices are stored as given, not renumbered.
            for key, value in entries:
                target[key] = value

            logger.info(
                "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.".
                format(namespace, len(target)))

    def extend_from_counter(self,
                            counters,
                            min_count=None,
                            no_pad_namespace=None,
                            no_unk_namespace=None,
                            contain_pad_namespace=None,
                            contain_unk_namespace=None):
        """Build namespaces from token counters.

        The settings passed here are merged into the ones already stored on the
        instance. A token is kept when its count reaches the namespace's
        minimum count; namespaces without a configured minimum use 1.

        Arguments:
            counters {dict} -- namespace -> {token: count}; ``None`` means nothing to add

        Keyword Arguments:
            min_count {dict} -- namespace -> minimum count (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})
        """

        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(contain_pad_namespace or {})
        self.contain_unk_namespace.update(contain_unk_namespace or {})
        self.min_count.update(min_count or {})

        for namespace, counter in (counters or {}).items():
            # NOTE(review): this recreates the namespace, so tokens already
            # stored under the same name are discarded -- confirm intended.
            self.__namespace_init(namespace)
            target = self.vocab[namespace]

            # Read the merged settings so thresholds registered earlier (in the
            # constructor or a previous call) still apply to this call.
            threshold = self.min_count.get(namespace, 1)

            for key, count in counter.items():
                if count >= threshold:
                    # Next free index: tokens are numbered in insertion order.
                    target[key] = len(target)

            logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format(
                namespace, len(target)))

    def add_tokens_to_namespace(self, tokens, namespace):
        """Append tokens to a namespace, creating the namespace when needed.

        Tokens already present keep their index; new tokens get the next free
        index in the order given.

        Arguments:
            tokens {list} -- token list
            namespace {str} -- namespace name
        """

        if namespace not in self.vocab:
            self.__namespace_init(namespace)
            # Implicit creation is worth flagging, but it is not a failure.
            logger.warning('Add Namespace {} into vocabulary.'.format(namespace))

        target = self.vocab[namespace]
        for token in tokens:
            if token not in target:
                target[token] = len(target)

    def get_token_index(self, token, namespace):
        """Return the index of a token, falling back to the unknown-token index.

        Arguments:
            token {str} -- token
            namespace {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist
            RuntimeError: the token is missing and the namespace has no unknown token

        Returns:
            int -- token index
        """

        table = self.vocab[namespace]
        if token in table:
            return table[token]

        if namespace not in self.no_unk_namespace:
            return self.get_unknown_index(namespace)

        message = "Can not find the index of {} from a no unknown token namespace {}.".format(
            token, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_token_from_index(self, index, namespace):
        """Return the token stored under an index.

        Arguments:
            index {int} -- index
            namespace {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist
            RuntimeError: no token is stored under this index

        Returns:
            str -- token
        """

        # Test membership instead of comparing against the size: negative
        # indices and gaps left by pretrained indices must not leak a KeyError.
        inverse = self.vocab[namespace].inv
        if index in inverse:
            return inverse[index]

        message = "The index {} is out of vocabulary {} range.".format(index, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_vocab_size(self, namespace):
        """Return the number of tokens stored in a namespace.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist

        Returns:
            int -- vocabulary size
        """

        return len(self.vocab[namespace])

    def get_all_namespaces(self):
        """Return the names of all namespaces.

        Returns:
            set -- names of all namespaces the vocabulary contains
        """

        return set(self.vocab)

    def get_padding_index(self, namespace):
        """Return the index of the padding token of a namespace.

        The padding token is the one configured in ``contain_pad_namespace``
        for this namespace, or the default padding token otherwise.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: the namespace does not exist or has no padding token
            KeyError: the padding token is not stored in the namespace

        Returns:
            int -- padding index
        """

        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        if namespace in self.no_pad_namespace:
            message = "Namespace {} doesn't has paddding token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)

        pad_token = self.contain_pad_namespace.get(namespace, Vocabulary.DEFAULT_PAD_TOKEN)
        return self.vocab[namespace][pad_token]

    def get_unknown_index(self, namespace):
        """Return the index of the unknown token of a namespace.

        The unknown token is the one configured in ``contain_unk_namespace``
        for this namespace, or the default unknown token otherwise.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: the namespace does not exist or has no unknown token
            KeyError: the unknown token is not stored in the namespace

        Returns:
            int -- unknown index
        """

        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        if namespace in self.no_unk_namespace:
            message = "Namespace {} doesn't has unknown token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)

        unk_token = self.contain_unk_namespace.get(namespace, Vocabulary.DEFAULT_UNK_TOKEN)
        return self.vocab[namespace][unk_token]

    def get_namespace_tokens(self, namesapce):
        """Return the token -> index mapping of a namespace.

        The stored mapping itself is returned, not a copy; iterating over it
        yields the tokens.

        Arguments:
            namesapce {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist

        Returns:
            bidict -- token -> index mapping of the namespace
        """

        return self.vocab[namesapce]

    def save(self, file_path):
        """Pickle this vocabulary into a file.

        Arguments:
            file_path {str} -- file path
        """

        # The context manager closes (and flushes) the file even when
        # pickling fails.
        with open(file_path, 'wb') as handle:
            pickle.dump(self, handle)

    @classmethod
    def load(cls, file_path):
        """Load a pickled vocabulary from a file.

        Arguments:
            file_path {str} -- file path

        Returns:
            Vocabulary -- vocabulary
        """

        # SECURITY: unpickling can execute arbitrary code -- only load files
        # that come from a trusted source.
        with open(file_path, 'rb') as handle:
            return pickle.load(handle, encoding='utf-8')

    def __namespace_init(self, namespace):
        """Create (or reset) a namespace and add its default special tokens.

        The default padding token and then the default unknown token are
        added, each only when the namespace neither opts out of it nor
        declares a token of its own for that role.

        Arguments:
            namespace {str} -- namespace
        """

        table = bidict()
        self.vocab[namespace] = table

        if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace:
            table[Vocabulary.DEFAULT_PAD_TOKEN] = len(table)

        if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace:
            table[Vocabulary.DEFAULT_UNK_TOKEN] = len(table)