File size: 11,732 Bytes
4fb0bd1 |
|
from bidict import bidict
import pickle
import logging
logger = logging.getLogger(__name__)
class Vocabulary:
    """Map tokens to integer indices, with one independent mapping per namespace.

    Every namespace is stored in ``self.vocab`` as a ``bidict`` from token to
    index. When a namespace is created it is seeded with ``DEFAULT_PAD_TOKEN``
    and ``DEFAULT_UNK_TOKEN`` unless it is listed in ``no_pad_namespace`` /
    ``no_unk_namespace`` or has a custom token registered in
    ``contain_pad_namespace`` / ``contain_unk_namespace``.
    """

    DEFAULT_PAD_TOKEN = '*@PAD@*'
    DEFAULT_UNK_TOKEN = '*@UNK@*'

    def __init__(self,
                 counters=None,
                 min_count=None,
                 pretrained_vocab=None,
                 intersection_namespace=None,
                 no_pad_namespace=None,
                 no_unk_namespace=None,
                 contain_pad_namespace=None,
                 contain_unk_namespace=None):
        """Build the vocabulary from token counters and pretrained vocabularies.

        Counters are processed first, pretrained vocabularies second.

        Keyword Arguments:
            counters {dict} -- namespace -> token counter (default: {None}, treated as empty)
            min_count {dict} -- namespace -> minimum count a token needs to be kept (default: {None})
            pretrained_vocab {dict} -- namespace -> {token: index} mapping (default: {None})
            intersection_namespace {dict} -- pretrained namespace -> existing namespace; only
                pretrained tokens also present in that existing namespace are kept (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})
        """
        # Copy every configuration argument so later updates never touch the
        # caller's objects; ``None`` stands for "nothing configured".
        self.min_count = dict(min_count or {})
        self.intersection_namespace = dict(intersection_namespace or {})
        self.no_pad_namespace = set(no_pad_namespace or ())
        self.no_unk_namespace = set(no_unk_namespace or ())
        self.contain_pad_namespace = dict(contain_pad_namespace or {})
        self.contain_unk_namespace = dict(contain_unk_namespace or {})
        self.vocab = dict()

        self.extend_from_counter(counters or {}, self.min_count, self.no_pad_namespace,
                                 self.no_unk_namespace)
        self.extend_from_pretrained_vocab(pretrained_vocab or {}, self.intersection_namespace,
                                          self.no_pad_namespace, self.no_unk_namespace)
        logger.info("Initialize vocabulary successfully.")

    def extend_from_pretrained_vocab(self,
                                     pretrained_vocab,
                                     intersection_namespace=None,
                                     no_pad_namespace=None,
                                     no_unk_namespace=None,
                                     contain_pad_namespace=None,
                                     contain_unk_namespace=None):
        """Create namespaces from pretrained {token: index} mappings.

        Each namespace found in ``pretrained_vocab`` is (re)initialized, so any
        tokens it previously held are discarded. Indices are taken verbatim
        from the pretrained mapping.

        Arguments:
            pretrained_vocab {dict} -- namespace -> {token: index} mapping

        Keyword Arguments:
            intersection_namespace {dict} -- pretrained namespace -> existing namespace; only
                pretrained tokens also present in that existing namespace are kept (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})
        """
        # Merge the per-call configuration into the instance-wide settings
        # before any namespace is initialized, because initialization reads them.
        self.intersection_namespace.update(dict(intersection_namespace or {}))
        self.no_pad_namespace.update(set(no_pad_namespace or ()))
        self.no_unk_namespace.update(set(no_unk_namespace or ()))
        self.contain_pad_namespace.update(dict(contain_pad_namespace or {}))
        self.contain_unk_namespace.update(dict(contain_unk_namespace or {}))

        for namespace, vocab in pretrained_vocab.items():
            self.__namespace_init(namespace)
            is_intersection = namespace in self.intersection_namespace
            # Looked up after initialization, exactly once per namespace.
            intersection_vocab = (self.vocab[self.intersection_namespace[namespace]]
                                  if is_intersection else [])
            target = self.vocab[namespace]
            for key, value in vocab.items():
                if not is_intersection or key in intersection_vocab:
                    target[key] = value
            logger.info(
                "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.".
                format(namespace, len(target)))

    def extend_from_counter(self,
                            counters,
                            min_count=None,
                            no_pad_namespace=None,
                            no_unk_namespace=None,
                            contain_pad_namespace=None,
                            contain_unk_namespace=None):
        """Create namespaces from token counters.

        Each namespace found in ``counters`` is (re)initialized, so any tokens
        it previously held are discarded. Tokens whose count reaches the
        namespace's threshold receive consecutive indices.

        Arguments:
            counters {dict} -- namespace -> {token: count} mapping

        Keyword Arguments:
            min_count {dict} -- namespace -> minimum count; namespaces missing from this
                argument use a threshold of 1 (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})
        """
        thresholds = min_count or {}
        self.no_pad_namespace.update(set(no_pad_namespace or ()))
        self.no_unk_namespace.update(set(no_unk_namespace or ()))
        self.contain_pad_namespace.update(dict(contain_pad_namespace or {}))
        self.contain_unk_namespace.update(dict(contain_unk_namespace or {}))
        self.min_count.update(dict(thresholds))

        for namespace, counter in counters.items():
            self.__namespace_init(namespace)
            # The threshold is constant for the namespace; resolve it once
            # instead of once per token.
            minc = thresholds[namespace] if namespace in thresholds else 1
            target = self.vocab[namespace]
            for key, count in counter.items():
                if count >= minc:
                    target[key] = len(target)
            logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format(
                namespace, len(target)))

    def add_tokens_to_namespace(self, tokens, namespace):
        """Append tokens that are not yet known to a namespace.

        The namespace is created first when it does not exist. Existing tokens
        keep their index; new tokens get the current namespace size as index.

        Arguments:
            tokens {list} -- token list
            namespace {str} -- namespace name
        """
        if namespace not in self.vocab:
            self.__namespace_init(namespace)
            # Creating a namespace implicitly is worth noticing but is not a failure.
            logger.warning('Add Namespace {} into vocabulary.'.format(namespace))
        target = self.vocab[namespace]
        for token in tokens:
            if token not in target:
                target[token] = len(target)

    def get_token_index(self, token, namespace):
        """Return the index of a token, falling back to the unknown-token index.

        Arguments:
            token {str} -- token
            namespace {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist
            RuntimeError: the token is missing and the namespace has no unknown token

        Returns:
            int -- token index
        """
        target = self.vocab[namespace]
        if token in target:
            return target[token]
        if namespace not in self.no_unk_namespace:
            return self.get_unknown_index(namespace)
        message = "Can not find the index of {} from a no unknown token namespace {}.".format(
            token, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_token_from_index(self, index, namespace):
        """Return the token stored under an index.

        Arguments:
            index {int} -- index
            namespace {str} -- namespace name

        Raises:
            KeyError: the namespace does not exist
            RuntimeError: no token of the namespace has this index

        Returns:
            str -- token
        """
        inverse = self.vocab[namespace].inv
        # Test membership rather than comparing against the size: indices taken
        # from a pretrained vocabulary need not be contiguous, and a negative
        # index is never valid.
        if index in inverse:
            return inverse[index]
        message = "The index {} is out of vocabulary {} range.".format(index, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_vocab_size(self, namespace):
        """Return the number of tokens in a namespace.

        Arguments:
            namespace {str} -- namespace name

        Returns:
            int -- vocabulary size
        """
        return len(self.vocab[namespace])

    def get_all_namespaces(self):
        """Return the names of all namespaces.

        Returns:
            set -- names of every namespace the vocabulary contains
        """
        return set(self.vocab)

    def get_padding_index(self, namespace):
        """Return the index of the padding token of a namespace.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: the namespace does not exist or has no padding token

        Returns:
            int -- padding index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace in self.no_pad_namespace:
            message = "Namespace {} doesn't has paddding token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)
        # A namespace either registered its own padding token or uses the default.
        pad_token = self.contain_pad_namespace.get(namespace, Vocabulary.DEFAULT_PAD_TOKEN) \
            if namespace in self.contain_pad_namespace else Vocabulary.DEFAULT_PAD_TOKEN
        return self.vocab[namespace][pad_token]

    def get_unknown_index(self, namespace):
        """Return the index of the unknown token of a namespace.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: the namespace does not exist or has no unknown token

        Returns:
            int -- unknown index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace in self.no_unk_namespace:
            message = "Namespace {} doesn't has unknown token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)
        # A namespace either registered its own unknown token or uses the default.
        unk_token = self.contain_unk_namespace[namespace] \
            if namespace in self.contain_unk_namespace else Vocabulary.DEFAULT_UNK_TOKEN
        return self.vocab[namespace][unk_token]

    def get_namespace_tokens(self, namesapce):
        """Return the token -> index mapping of a namespace.

        The stored mapping itself is returned, not a copy; iterating over it
        yields the tokens.

        Arguments:
            namesapce {str} -- namespace name (the misspelling is kept so that
                existing keyword callers continue to work)

        Returns:
            bidict -- mapping from token to index
        """
        return self.vocab[namesapce]

    def save(self, file_path):
        """Pickle this vocabulary into a file.

        Arguments:
            file_path {str} -- file path
        """
        # The context manager guarantees the data is flushed and the handle closed.
        with open(file_path, 'wb') as handle:
            pickle.dump(self, handle)

    @classmethod
    def load(cls, file_path):
        """Unpickle a vocabulary from a file.

        Arguments:
            file_path {str} -- file path

        Returns:
            Vocabulary -- vocabulary
        """
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(file_path, 'rb') as handle:
            return pickle.load(handle, encoding='utf-8')

    def __namespace_init(self, namespace):
        """(Re)create a namespace and seed it with the default special tokens.

        Any existing content of the namespace is replaced by a fresh ``bidict``.
        The default padding / unknown token is added only when the namespace is
        neither excluded from it nor configured with its own token.

        Arguments:
            namespace {str} -- namespace
        """
        target = bidict()
        self.vocab[namespace] = target
        if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace:
            target[Vocabulary.DEFAULT_PAD_TOKEN] = len(target)
        if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace:
            target[Vocabulary.DEFAULT_UNK_TOKEN] = len(target)
|