File size: 11,732 Bytes
4fb0bd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 |
from bidict import bidict
import pickle
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Vocabulary():
    """This class maps strings to integers and supports many namespaces

    Every namespace is stored in ``self.vocab`` as a ``bidict`` from token to
    index. When a namespace is initialised, the default padding / unknown
    tokens are inserted unless the namespace is listed in
    ``no_pad_namespace`` / ``no_unk_namespace`` or has an entry in
    ``contain_pad_namespace`` / ``contain_unk_namespace``.
    """

    DEFAULT_PAD_TOKEN = '*@PAD@*'
    DEFAULT_UNK_TOKEN = '*@UNK@*'

    def __init__(self,
                 counters=None,
                 min_count=None,
                 pretrained_vocab=None,
                 intersection_namespace=None,
                 no_pad_namespace=None,
                 no_unk_namespace=None,
                 contain_pad_namespace=None,
                 contain_unk_namespace=None):
        """initialize vocabulary

        The vocabulary is first extended from ``counters`` and then from
        ``pretrained_vocab``. ``None`` for any argument means "empty".

        Keyword Arguments:
            counters {dict} -- namespace -> token counter (default: {None})
            min_count {dict} -- namespace -> min count (default: {None})
            pretrained_vocab {dict} -- namespace -> pretrained token/index
                mapping (default: {None})
            intersection_namespace {dict} -- pretrained namespace -> existing
                namespace whose tokens filter the pretrained vocabulary, in
                case of too large pretrained vocabulary (default: {None})
            no_pad_namespace {list} -- namespaces without padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding
                token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown
                token (default: {None})
        """
        # Copy every argument so the vocabulary never aliases caller objects.
        self.min_count = dict(min_count or {})
        self.intersection_namespace = dict(intersection_namespace or {})
        self.no_pad_namespace = set(no_pad_namespace or ())
        self.no_unk_namespace = set(no_unk_namespace or ())
        self.contain_pad_namespace = dict(contain_pad_namespace or {})
        self.contain_unk_namespace = dict(contain_unk_namespace or {})
        self.vocab = dict()

        # The configuration is already stored on self, so the extend methods
        # only need the data to ingest.
        self.extend_from_counter(counters or {})
        self.extend_from_pretrained_vocab(pretrained_vocab or {})
        logger.info("Initialize vocabulary successfully.")

    def extend_from_pretrained_vocab(self,
                                     pretrained_vocab,
                                     intersection_namespace=None,
                                     no_pad_namespace=None,
                                     no_unk_namespace=None,
                                     contain_pad_namespace=None,
                                     contain_unk_namespace=None):
        """extend vocabulary from pretrained vocab

        Each namespace in ``pretrained_vocab`` is (re)initialised before it
        is filled, so an already existing namespace of the same name is
        rebuilt. Tokens keep the index given by the pretrained mapping.

        Arguments:
            pretrained_vocab {dict} -- namespace -> pretrained token/index
                mapping

        Keyword Arguments:
            intersection_namespace {dict} -- pretrained namespace -> existing
                namespace; only pretrained tokens present in that namespace
                are kept (default: {None})
            no_pad_namespace {list} -- namespaces without padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding
                token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown
                token (default: {None})
        """
        self.intersection_namespace.update(intersection_namespace or {})
        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(contain_pad_namespace or {})
        self.contain_unk_namespace.update(contain_unk_namespace or {})

        for namespace, vocab in pretrained_vocab.items():
            self.__namespace_init(namespace)
            namespace_vocab = self.vocab[namespace]

            # None means "no filtering": every pretrained token is accepted.
            allowed_tokens = None
            if namespace in self.intersection_namespace:
                allowed_tokens = self.vocab[self.intersection_namespace[namespace]]

            for key, value in vocab.items():
                if allowed_tokens is None or key in allowed_tokens:
                    namespace_vocab[key] = value

            logger.info(
                "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.".
                format(namespace, len(namespace_vocab)))

    def extend_from_counter(self,
                            counters,
                            min_count=None,
                            no_pad_namespace=None,
                            no_unk_namespace=None,
                            contain_pad_namespace=None,
                            contain_unk_namespace=None):
        """extend vocabulary from counter

        Each namespace in ``counters`` is (re)initialised before it is
        filled, so an already existing namespace of the same name is rebuilt.
        A token is added when its count reaches the namespace's min count
        (1 when no min count is configured for the namespace).

        Arguments:
            counters {dict} -- namespace -> token counter

        Keyword Arguments:
            min_count {dict} -- namespace -> min count, merged into the
                stored min counts (default: {None})
            no_pad_namespace {list} -- namespaces without padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding
                token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown
                token (default: {None})
        """
        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(contain_pad_namespace or {})
        self.contain_unk_namespace.update(contain_unk_namespace or {})
        self.min_count.update(min_count or {})

        for namespace, counter in counters.items():
            self.__namespace_init(namespace)
            namespace_vocab = self.vocab[namespace]
            # Use the merged thresholds so a min count configured earlier
            # still applies when this call does not repeat it.
            threshold = self.min_count.get(namespace, 1)

            for key, count in counter.items():
                # Skip tokens that are already present (e.g. the default
                # pad/unk tokens): re-assigning them would move their index
                # and break the contiguous numbering.
                if count >= threshold and key not in namespace_vocab:
                    namespace_vocab[key] = len(namespace_vocab)

            logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format(
                namespace, len(namespace_vocab)))

    def add_tokens_to_namespace(self, tokens, namespace):
        """This function adds tokens to one namespace for extending vocabulary

        The namespace is created first when it does not exist yet. Tokens
        already present keep their index; new tokens get the current size of
        the namespace as index.

        Arguments:
            tokens {list} -- token list
            namespace {str} -- namespace name
        """
        if namespace not in self.vocab:
            self.__namespace_init(namespace)
            # Creating a namespace on demand is expected behaviour, not a
            # failure, so it is reported at info level.
            logger.info('Add Namespace {} into vocabulary.'.format(namespace))

        namespace_vocab = self.vocab[namespace]
        for token in tokens:
            if token not in namespace_vocab:
                namespace_vocab[token] = len(namespace_vocab)

    def get_token_index(self, token, namespace):
        """This function gets token index in one namespace of vocabulary

        An unseen token maps to the unknown index unless the namespace is a
        no-unknown namespace.

        Arguments:
            token {str} -- token
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace not exists, or the token is missing from
                a namespace without unknown token

        Returns:
            int -- token index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        namespace_vocab = self.vocab[namespace]
        if token in namespace_vocab:
            return namespace_vocab[token]

        if namespace not in self.no_unk_namespace:
            return self.get_unknown_index(namespace)

        message = "Can not find the index of {} from a no unknown token namespace {}.".format(
            token, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_token_from_index(self, index, namespace):
        """This function gets token using index in vocabulary

        Arguments:
            index {int} -- index
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace not exists, or no token has this index

        Returns:
            str -- token
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        # Look the index up in the inverse mapping instead of comparing it
        # with the size: this also rejects negative indices and accepts
        # indices that are not contiguous (e.g. taken from a pretrained
        # vocabulary).
        index_to_token = self.vocab[namespace].inv
        if index in index_to_token:
            return index_to_token[index]

        message = "The index {} is out of vocabulary {} range.".format(index, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_vocab_size(self, namespace):
        """This function gets the size of one namespace in vocabulary

        Arguments:
            namespace {str} -- namespace name

        Raises:
            KeyError: namespace not exists

        Returns:
            int -- vocabulary size
        """
        return len(self.vocab[namespace])

    def get_all_namespaces(self):
        """This function gets all namespaces

        Returns:
            set -- names of all namespaces the vocabulary contains
        """
        return set(self.vocab)

    def get_padding_index(self, namespace):
        """This function gets padding token index in one namespace of vocabulary

        The namespace's own padding token from ``contain_pad_namespace`` is
        used when registered, otherwise the default padding token.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace not exists, or it has no padding token

        Returns:
            int -- padding index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        if namespace in self.no_pad_namespace:
            message = "Namespace {} doesn't has paddding token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)

        pad_token = self.contain_pad_namespace.get(namespace, Vocabulary.DEFAULT_PAD_TOKEN)
        return self.vocab[namespace][pad_token]

    def get_unknown_index(self, namespace):
        """This function gets unknown token index in one namespace of vocabulary

        The namespace's own unknown token from ``contain_unk_namespace`` is
        used when registered, otherwise the default unknown token.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace not exists, or it has no unknown token

        Returns:
            int -- unknown index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))

        if namespace in self.no_unk_namespace:
            message = "Namespace {} doesn't has unknown token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)

        unk_token = self.contain_unk_namespace.get(namespace, Vocabulary.DEFAULT_UNK_TOKEN)
        return self.vocab[namespace][unk_token]

    def get_namespace_tokens(self, namesapce):
        """This function returns all tokens in one namespace

        The stored mapping itself is returned (not a copy); iterating over
        it yields the tokens.

        Arguments:
            namesapce {str} -- namespace name

        Raises:
            KeyError: namespace not exists

        Returns:
            bidict -- token -> index mapping of the namespace
        """
        return self.vocab[namesapce]

    def save(self, file_path):
        """This function saves vocabulary into file

        The whole object is pickled.

        Arguments:
            file_path {str} -- file path
        """
        # The context manager closes the file even when pickling fails.
        with open(file_path, 'wb') as vocab_file:
            pickle.dump(self, vocab_file)

    @classmethod
    def load(cls, file_path):
        """This function loads vocabulary from file

        Arguments:
            file_path {str} -- file path

        Returns:
            Vocabulary -- vocabulary
        """
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(file_path, 'rb') as vocab_file:
            return pickle.load(vocab_file, encoding='utf-8')

    def __namespace_init(self, namespace):
        """This function initializes a namespace,
        adds pad and unk token to one namespace of vocabulary

        Any existing content of the namespace is replaced by a fresh, empty
        mapping. The default tokens are skipped for namespaces that have no
        such token or that bring their own one.

        Arguments:
            namespace {str} -- namespace
        """
        namespace_vocab = bidict()
        self.vocab[namespace] = namespace_vocab

        if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace:
            namespace_vocab[Vocabulary.DEFAULT_PAD_TOKEN] = len(namespace_vocab)

        if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace:
            namespace_vocab[Vocabulary.DEFAULT_UNK_TOKEN] = len(namespace_vocab)
|