import codecs
from collections import defaultdict
import logging
import os
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union, TYPE_CHECKING

from filelock import FileLock

logger = logging.getLogger(__name__)

DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
DEFAULT_PADDING_TOKEN = "@@PADDING@@"
DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
_NEW_LINE_REGEX = re.compile(r"\n|\r\n")


def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Matches a namespace pattern against a namespace string.  For example, `*tags` matches
    `passage_tags` and `question_tags` and `tokens` matches `tokens` but not
    `stemmed_tokens`.

    A pattern beginning with `*` is a suffix match on the rest of the pattern; any other
    pattern must equal the namespace exactly.  An empty pattern only matches an empty
    namespace.
    """
    # `startswith` rather than `pattern[0]`, so an empty pattern does not raise IndexError.
    if pattern.startswith("*") and namespace.endswith(pattern[1:]):
        return True
    return pattern == namespace


class _NamespaceDependentDefaultDict(defaultdict):
    """
    This is a [defaultdict]
    (https://docs.python.org/3/library/collections.html#collections.defaultdict) where the
    default value is dependent on the key that is passed.

    We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
    mappings from strings to integers, so that we have a consistent API for mapping words, tags,
    labels, characters, or whatever else you want, into integers.  The issue is that some of those
    namespaces (words and characters) should have integers reserved for padding and
    out-of-vocabulary tokens, while others (labels and tags) shouldn't.  This class allows you to
    specify filters on the namespace (the key used in the `defaultdict`), and use different
    default values depending on whether the namespace passes the filter.

    To do filtering, we take a set of `non_padded_namespaces`.  This is a set of strings
    that are either matched exactly against the keys, or treated as suffixes, if the
    string starts with `*`.  In other words, if `*tags` is in `non_padded_namespaces` then
    `passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
    `non_padded` default value.

    # Parameters

    non_padded_namespaces : `Iterable[str]`
        A set / list / tuple of strings describing which namespaces are not padded.  If a
        namespace (key) is missing from this dictionary, we will use :func:`namespace_match` to see
        whether the namespace should be padded.  If the given namespace matches any of the strings
        in this list, we will use `non_padded_function` to initialize the value for that namespace,
        and we will use `padded_function` otherwise.
    padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that `should` be
        padded.
    non_padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that should `not` be
        padded.
    """

    def __init__(
        self,
        non_padded_namespaces: Iterable[str],
        padded_function: Callable[[], Any],
        non_padded_function: Callable[[], Any],
    ) -> None:
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padded_function = padded_function
        self._non_padded_function = non_padded_function
        super().__init__()

    def __missing__(self, key: str) -> Any:
        """
        Create, store and return the default value for an absent namespace `key`.

        `dict.__getitem__` calls this for missing keys.  No `default_factory` is passed to
        `defaultdict`, so without this override every new namespace raised `KeyError`.
        """
        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
            value = self._non_padded_function()
        else:
            value = self._padded_function()
        dict.__setitem__(self, key, value)
        return value

    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]) -> None:
        """Add patterns to the non-padded set; only affects namespaces created afterwards."""
        self._non_padded_namespaces.update(non_padded_namespaces)


class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """Namespace -> {token: index}; padded namespaces start as {padding: 0, oov: 1}."""

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(
            non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {}
        )


class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """Namespace -> {index: token}; padded namespaces start as {0: padding, 1: oov}."""

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(
            non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {}
        )


class Vocabulary:
    """
    Maps strings (tokens, tags, labels, ...) to integer indices, separately per namespace.

    Namespaces matching one of `non_padded_namespaces` (see `namespace_match`) start empty.
    Every other namespace starts with the padding token at index 0 and the OOV token at index 1.
    """

    def __init__(
        self,
        counter: Optional[Dict[str, Dict[str, int]]] = None,
        min_count: Optional[Dict[str, int]] = None,
        max_vocab_size: Optional[Union[int, Dict[str, int]]] = None,
        non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
        pretrained_files: Optional[Dict[str, str]] = None,
        only_include_pretrained_words: bool = False,
        tokens_to_add: Optional[Dict[str, List[str]]] = None,
        min_pretrained_embeddings: Optional[Dict[str, int]] = None,
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> None:
        """
        # Parameters

        counter, min_count, max_vocab_size, pretrained_files, only_include_pretrained_words,
        tokens_to_add, min_pretrained_embeddings :
            Accepted but not used by this constructor, so the vocabulary always starts empty.
            NOTE(review): confirm whether callers expect these to populate the vocabulary.
        non_padded_namespaces : `Iterable[str]`, optional (default=`DEFAULT_NON_PADDED_NAMESPACES`)
            Namespace patterns that get no padding / OOV entries.
        padding_token : `Optional[str]`, optional (default=`DEFAULT_PADDING_TOKEN`)
            `None` falls back to `DEFAULT_PADDING_TOKEN`.
        oov_token : `Optional[str]`, optional (default=`DEFAULT_OOV_TOKEN`)
            `None` falls back to `DEFAULT_OOV_TOKEN`.
        """
        self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN

        self._non_padded_namespaces = set(non_padded_namespaces)
        self._token_to_index = _TokenToIndexDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )
        self._index_to_token = _IndexToTokenDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )

    @classmethod
    def from_files(
        cls,
        directory: Union[str, os.PathLike],
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> "Vocabulary":
        """
        Loads a `Vocabulary` serialized as one file per namespace inside `directory`.

        The directory must contain `NAMESPACE_PADDING_FILE`, which holds one non-padded
        namespace pattern per line (blank lines are ignored). Every other regular file is
        loaded with `set_from_file` into the namespace named after the file, minus a trailing
        `.txt`. Entries whose name starts with `.` (including the `.lock` file used here) and
        sub-directories are skipped.

        # Parameters

        directory : `Union[str, os.PathLike]`
            The directory containing the serialized vocabulary.  Anything that is not a
            directory (e.g. an archive file) raises `ValueError`.
        padding_token : `Optional[str]`, optional (default=`DEFAULT_PADDING_TOKEN`)
            `None` falls back to `DEFAULT_PADDING_TOKEN`.
        oov_token : `Optional[str]`, optional (default=`DEFAULT_OOV_TOKEN`)
            The OOV string used in the files; `None` falls back to `DEFAULT_OOV_TOKEN`.
        """
        logger.info("Loading token dictionary from %s.", directory)
        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN

        if not os.path.isdir(directory):
            raise ValueError(f"{directory} not exist")

        # We use a lock file to avoid race conditions where multiple processes
        # might be reading/writing from/to the same vocab files at once.
        with FileLock(os.path.join(directory, ".lock")):
            padding_path = os.path.join(directory, NAMESPACE_PADDING_FILE)
            with codecs.open(padding_path, "r", "utf-8") as namespace_file:
                stripped = (namespace_str.strip() for namespace_str in namespace_file)
                # Drop blank lines (e.g. a trailing newline) so they never become patterns.
                non_padded_namespaces = [pattern for pattern in stripped if pattern]

            vocab = cls(
                non_padded_namespaces=non_padded_namespaces,
                padding_token=padding_token,
                oov_token=oov_token,
            )

            # Check every file in the directory.
            for namespace_filename in os.listdir(directory):
                if namespace_filename == NAMESPACE_PADDING_FILE:
                    continue
                if namespace_filename.startswith("."):
                    continue
                filename = os.path.join(directory, namespace_filename)
                if not os.path.isfile(filename):
                    # Sub-directories are not vocabulary files.
                    continue
                # Remove only a trailing ".txt", not every occurrence in the name.
                namespace = namespace_filename
                if namespace.endswith(".txt"):
                    namespace = namespace[: -len(".txt")]
                is_padded = not any(
                    namespace_match(pattern, namespace) for pattern in non_padded_namespaces
                )
                vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)

        return vocab

    @classmethod
    def empty(cls) -> "Vocabulary":
        """
        This method returns a bare vocabulary instantiated with `cls()` (so, `Vocabulary()` if you
        haven't made a subclass of this object).  It is useful as a named constructor when
        building the vocabulary from a config file and you know that you don't need to compute a
        vocabulary.  Examples: you're loading a pre-trained model, or you're using a pre-trained
        transformer with its own vocabulary.
        """
        return cls()

    def set_from_file(
        self,
        filename: str,
        is_padded: bool = True,
        oov_token: str = DEFAULT_OOV_TOKEN,
        namespace: str = "tokens",
    ) -> None:
        """
        Overwrite `namespace` with the contents of an existing vocabulary file, instead of
        building the vocabulary from a dataset.

        # Parameters

        filename : `str`
            The file containing the vocabulary to load.  It should be formatted as one token per
            line, with nothing else in the line.  The index we assign to the token is the line
            number in the file (1-indexed if `is_padded`, 0-indexed otherwise).  The literal
            `@@NEWLINE@@` in a line is turned into a `"\\n"` character.  Note that a padded
            vocabulary file must contain the OOV token string!
        is_padded : `bool`, optional (default=`True`)
            Is this vocabulary padded?  For token / word / character vocabularies, this should be
            `True`; while for tag or label vocabularies, this should typically be `False`.  If
            `True`, we add a padding token with index 0, and we enforce that the `oov_token` is
            present in the file.
        oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
            The string this file uses for out-of-vocabulary tokens.  When found, it is replaced by
            `self._oov_token`, because we only use one OOV token across namespaces.
        namespace : `str`, optional (default=`"tokens"`)
            What namespace should we overwrite with this vocab file?

        # Raises

        AssertionError
            If `is_padded` and the OOV token is not in the file.  This is raised explicitly, so it
            still fires under `python -O`.
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}

        with codecs.open(filename, "r", "utf-8") as input_file:
            # codecs.open does no newline translation, hence the explicit "\r\n" handling.
            lines = _NEW_LINE_REGEX.split(input_file.read())
            # Be flexible about having final newline or not
            if lines and lines[-1] == "":
                lines = lines[:-1]
            for i, line in enumerate(lines):
                index = i + 1 if is_padded else i
                token = line.replace("@@NEWLINE@@", "\n")
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token

        if is_padded and self._oov_token not in self._token_to_index[namespace]:
            raise AssertionError("OOV token not found!")

    def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
        """
        Adds `token` to the index, if it is not already present.  Either way, we return the index of
        the token.

        # Raises

        ValueError
            If `token` is not a `str`.
        """
        if not isinstance(token, str):
            raise ValueError(
                "Vocabulary tokens must be strings, or saving and loading will break."
                "  Got %s (with type %s)" % (repr(token), type(token))
            )
        if token not in self._token_to_index[namespace]:
            index = len(self._token_to_index[namespace])
            self._token_to_index[namespace][token] = index
            self._index_to_token[namespace][index] = token
            return index
        else:
            return self._token_to_index[namespace][token]

    def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """
        Adds `tokens` to the index, if they are not already present.  Either way, we return the
        indices of the tokens in the order that they were given.
        """
        return [self.add_token_to_namespace(token, namespace) for token in tokens]

    def get_token_index(self, token: str, namespace: str = "tokens") -> int:
        """
        Return the index of `token`, falling back to the OOV token's index.

        # Raises

        KeyError
            If neither `token` nor the OOV token is present in `namespace`.
        """
        try:
            return self._token_to_index[namespace][token]
        except KeyError:
            try:
                return self._token_to_index[namespace][self._oov_token]
            except KeyError:
                logger.error("Namespace: %s", namespace)
                logger.error("Token: %s", token)
                raise KeyError(
                    f"'{token}' not found in vocab namespace '{namespace}', and namespace "
                    f"does not contain the default OOV token ('{self._oov_token}')"
                )

    def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
        """Return the token stored at `index`; raises `KeyError` if the index is unknown."""
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = "tokens") -> int:
        """Number of entries in `namespace`, including padding/OOV for padded namespaces."""
        return len(self._token_to_index[namespace])

    def get_namespaces(self) -> Set[str]:
        """Names of every namespace that has been created in the index-to-token mapping."""
        return set(self._index_to_token.keys())