import os
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Set
|
|
def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
    ``passage_tags`` and ``question_tags``, while ``tokens`` matches ``tokens`` but not
    ``stemmed_tokens``.
    """
    if pattern.startswith('*') and namespace.endswith(pattern[1:]):
        return True
    if pattern == namespace:
        return True
    return False
|
|
class _NamespaceDependentDefaultDict(defaultdict):
    """
    A ``defaultdict`` whose default value depends on the key: namespaces that
    match one of the non-padded patterns default to ``non_padded_function()``,
    and all other namespaces default to ``padded_function()``.
    """
    def __init__(self,
                 non_padded_namespaces: Set[str],
                 padded_function: Callable[[], Any],
                 non_padded_function: Callable[[], Any]) -> None:
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padded_function = padded_function
        self._non_padded_function = non_padded_function
        super().__init__()

    def __missing__(self, key: str):
        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
            value = self._non_padded_function()
        else:
            value = self._padded_function()
        dict.__setitem__(self, key, value)
        return value

    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
        self._non_padded_namespaces.update(non_padded_namespaces)
|
|
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(non_padded_namespaces,
                         lambda: {padding_token: 0, oov_token: 1},
                         lambda: {})
|
|
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(non_padded_namespaces,
                         lambda: {0: padding_token, 1: oov_token},
                         lambda: {})
|
|
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
DEFAULT_PADDING_TOKEN = '[PAD]'
DEFAULT_OOV_TOKEN = '[UNK]'
NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
|
|
class Vocabulary:
    """
    A two-way mapping between tokens and integer ids, organized into named
    namespaces (e.g. 'tokens' for input words, 'labels' for output tags).
    Padded namespaces reserve index 0 for the padding token and index 1 for
    the OOV token; namespaces matching a non-padded pattern start at 0 with
    no reserved entries.
    """
    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
|
    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        if token not in self._token_to_index[namespace]:
            index = len(self._token_to_index[namespace])
            self._token_to_index[namespace][token] = index
            self._index_to_token[namespace][index] = token
            return index
        else:
            return self._token_to_index[namespace][token]
|
    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        if token in self._token_to_index[namespace]:
            return self._token_to_index[namespace][token]
        else:
            return self._token_to_index[namespace][self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens') -> str:
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        return len(self._token_to_index[namespace])
|
    def save_to_files(self, directory: str):
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))

        for namespace, index_to_token in self._index_to_token.items():
            is_padded = not any(namespace_match(pattern, namespace)
                                for pattern in self._non_padded_namespaces)
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            with open(filename, 'w', encoding='utf-8') as f:
                # Skip the padding token (index 0) in padded namespaces:
                # ``set_from_file`` re-inserts it at index 0 when loading, so
                # writing it out would shift every index by one on a round trip.
                start_index = 1 if is_padded else 0
                for index in range(start_index, len(index_to_token)):
                    f.write('{}\n'.format(index_to_token[index]))
|
    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]

        vocab = cls(non_padded_namespaces=non_padded_namespaces)

        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            # Strip only the '.txt' suffix; str.replace would also mangle a
            # namespace whose name contains '.txt' in the middle.
            namespace = namespace_filename[:-len('.txt')]
            is_padded = not any(namespace_match(pattern, namespace)
                                for pattern in non_padded_namespaces)
            filename = os.path.join(directory, namespace_filename)
            vocab.set_from_file(filename, is_padded, namespace=namespace)

        return vocab
|
    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"):
        if is_padded:
            # Reserve index 0 for the padding token; the file is expected to
            # contain the OOV token but not the padding token.
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}

        with open(filename, 'r', encoding='utf-8') as f:
            index = 1 if is_padded else 0
            for row in f:
                token = row.strip()
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token
                index += 1
|
    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        result = list()
        for token in tokens:
            idx = self._token_to_index[namespace].get(token)
            if idx is None:
                # Unknown tokens map to the OOV id (raises KeyError in a
                # non-padded namespace, which has no OOV entry).
                idx = self._token_to_index[namespace][self._oov_token]
            result.append(idx)
        return result

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens") -> List[str]:
        result = list()
        for idx in ids:
            token = self._index_to_token[namespace][idx]
            result.append(token)
        return result
|
    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens") -> List[int]:
        pad_idx = self._token_to_index[namespace][self._padding_token]

        length = len(ids)
        if length > max_length:
            result = ids[:max_length]
        else:
            result = ids + [pad_idx] * (max_length - length)
        return result
|
|
def demo1():
    import jieba

    vocabulary = Vocabulary()
    vocabulary.add_token_to_namespace('白天', 'tokens')
    vocabulary.add_token_to_namespace('晚上', 'tokens')

    text = '不是在白天, 就是在晚上'
    tokens = jieba.lcut(text)
    print(tokens)

    # Tokens outside the vocabulary map to the [UNK] id (1).
    ids = vocabulary.convert_tokens_to_ids(tokens)
    print(ids)

    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
    print(padded_idx)

    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
    print(tokens)
    return
|
|
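def demo2():
    # A minimal sketch (added alongside demo1) showing non-padded namespaces:
    # names matching '*labels' get no [PAD]/[UNK] entries, so indices start
    # at 0. The 'sentiment_labels' namespace name is an illustrative choice.
    vocabulary = Vocabulary()
    vocabulary.add_token_to_namespace('positive', 'sentiment_labels')
    vocabulary.add_token_to_namespace('negative', 'sentiment_labels')

    # Non-padded namespace: {'positive': 0, 'negative': 1}.
    print(vocabulary.get_token_to_index_vocabulary('sentiment_labels'))
    # Padded namespace: indices 0 and 1 are reserved for [PAD] and [UNK].
    print(vocabulary.get_token_to_index_vocabulary('tokens'))
    return


def demo3():
    # A minimal sketch of the save/load round trip via save_to_files and
    # from_files; the temporary directory is an arbitrary choice here.
    import tempfile

    vocabulary = Vocabulary()
    vocabulary.add_token_to_namespace('白天', 'tokens')
    vocabulary.add_token_to_namespace('晚上', 'tokens')

    with tempfile.TemporaryDirectory() as directory:
        vocabulary.save_to_files(directory)
        restored = Vocabulary.from_files(directory)

    # The restored vocabulary assigns the same indices: 2 and 3.
    print(restored.get_token_index('白天'))
    print(restored.get_token_index('晚上'))
    return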
|
|
|
if __name__ == '__main__':
    demo1()
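    # Run the added demos as well.
    demo2()
    demo3()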