# HoneyTian's picture
# update
# f0a00b6
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from collections import defaultdict, OrderedDict
import os
from typing import Any, Callable, Dict, Iterable, List, Set
def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
    ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not
    ``stemmed_tokens``.

    :param pattern: an exact namespace name, or a ``*``-prefixed suffix pattern.
    :param namespace: the namespace name to test against the pattern.
    :return: True if the pattern matches the namespace, False otherwise.
    """
    # Guard: an empty pattern previously raised IndexError on pattern[0].
    if not pattern:
        return False
    if pattern[0] == '*':
        # Suffix match; note this also covers pattern == namespace for
        # '*'-prefixed patterns, since any string ends with its own tail.
        return namespace.endswith(pattern[1:])
    return pattern == namespace
class _NamespaceDependentDefaultDict(defaultdict):
def __init__(self,
non_padded_namespaces: Set[str],
padded_function: Callable[[], Any],
non_padded_function: Callable[[], Any]) -> None:
self._non_padded_namespaces = set(non_padded_namespaces)
self._padded_function = padded_function
self._non_padded_function = non_padded_function
super(_NamespaceDependentDefaultDict, self).__init__()
def __missing__(self, key: str):
if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
value = self._non_padded_function()
else:
value = self._padded_function()
dict.__setitem__(self, key, value)
return value
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
# add non_padded_namespaces which weren't already present
self._non_padded_namespaces.update(non_padded_namespaces)
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """
    Maps namespace -> {token: index}. Padded namespaces start with the padding
    token at index 0 and the OOV token at index 1; non-padded namespaces
    start empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(
            non_padded_namespaces,
            padded_function=lambda: {padding_token: 0, oov_token: 1},
            non_padded_function=lambda: {},
        )
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """
    Maps namespace -> {index: token}. Padded namespaces start with index 0
    bound to the padding token and index 1 to the OOV token; non-padded
    namespaces start empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(
            non_padded_namespaces,
            padded_function=lambda: {0: padding_token, 1: oov_token},
            non_padded_function=lambda: {},
        )
# Namespace name patterns that by default get no padding/OOV entries
# (label-like namespaces where index 0 is a real class, not padding).
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
# Token reserved at index 0 of padded namespaces for sequence padding.
DEFAULT_PADDING_TOKEN = '[PAD]'
# Token reserved at index 1 of padded namespaces for out-of-vocabulary lookups.
DEFAULT_OOV_TOKEN = '[UNK]'
# File inside a saved vocabulary directory listing the non-padded namespace patterns.
NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
class Vocabulary(object):
    """
    Maps tokens to integer indices (and back) within named namespaces
    (e.g. ``'tokens'``, ``'labels'``).

    Namespaces matching any pattern in ``non_padded_namespaces`` get no
    reserved entries; every other ("padded") namespace reserves index 0 for
    the padding token and index 1 for the OOV token.
    """

    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)

    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        """Add ``token`` to ``namespace`` if absent; return its index either way."""
        mapping = self._token_to_index[namespace]
        if token in mapping:
            return mapping[token]
        index = len(mapping)
        mapping[token] = index
        self._index_to_token[namespace][index] = token
        return index

    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        """Return the live ``{index: token}`` dict for ``namespace``."""
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        """Return the live ``{token: index}`` dict for ``namespace``."""
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        """Return the index of ``token``; unknown tokens map to the OOV index.

        Raises KeyError if the namespace has no OOV entry (non-padded namespaces).
        """
        mapping = self._token_to_index[namespace]
        if token in mapping:
            return mapping[token]
        return mapping[self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens') -> str:
        """Return the token at ``index``; raises KeyError for unknown indices."""
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        """Return the number of entries in ``namespace`` (including reserved ones)."""
        return len(self._token_to_index[namespace])

    def save_to_files(self, directory: str):
        """
        Persist this vocabulary to ``directory``: one ``<namespace>.txt`` file
        per namespace (one token per line, in index order), plus a file listing
        the non-padded namespace patterns.

        For padded namespaces the padding token (index 0) is NOT written:
        ``set_from_file(..., is_padded=True)`` re-inserts it at index 0 on
        load, so writing it here would shift every index by one on reload
        (the previous round-trip bug).
        """
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))
        for namespace, index_to_token in self._index_to_token.items():
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            # Padded namespaces are recognized by the padding token at index 0.
            start_index = 1 if index_to_token.get(0) == self._padding_token else 0
            with open(filename, 'w', encoding='utf-8') as f:
                # Iterate by index explicitly so the file order matches indices.
                for index in range(start_index, len(index_to_token)):
                    f.write('{}\n'.format(index_to_token[index]))

    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        """Reload a vocabulary previously written by :meth:`save_to_files`."""
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
        vocab = cls(non_padded_namespaces=non_padded_namespaces)
        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            # Strip only a trailing '.txt'; str.replace() would also remove a
            # '.txt' occurring in the middle of a namespace name.
            if namespace_filename.endswith('.txt'):
                namespace = namespace_filename[:-len('.txt')]
            else:
                namespace = namespace_filename
            is_padded = not any(namespace_match(pattern, namespace)
                                for pattern in non_padded_namespaces)
            filename = os.path.join(directory, namespace_filename)
            vocab.set_from_file(filename, is_padded, namespace=namespace)
        return vocab

    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"
                      ):
        """
        Replace the vocabulary of ``namespace`` with the tokens in ``filename``
        (one token per line, in index order).

        If ``is_padded``, the padding token takes index 0 and file tokens start
        at index 1; otherwise file tokens start at index 0. A line equal to
        ``oov_token`` is mapped onto this vocabulary's own OOV token.
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}
        with open(filename, 'r', encoding='utf-8') as f:
            index = 1 if is_padded else 0
            for row in f:
                token = str(row).strip()
                # Tolerate files written by older code that (incorrectly)
                # included the padding token: it already sits at index 0.
                if is_padded and token == self._padding_token:
                    continue
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token
                index += 1

    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """Map each token to its index, using the OOV index for unknown tokens."""
        mapping = self._token_to_index[namespace]
        result = list()
        for token in tokens:
            idx = mapping.get(token)
            if idx is None:
                idx = mapping[self._oov_token]
            result.append(idx)
        return result

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens") -> List[str]:
        """Map each index back to its token; raises KeyError for unknown ids."""
        mapping = self._index_to_token[namespace]
        return [mapping[idx] for idx in ids]

    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens") -> List[int]:
        """Truncate ``ids`` to ``max_length`` or right-pad with the padding index."""
        pad_idx = self._token_to_index[namespace][self._padding_token]
        length = len(ids)
        if length > max_length:
            return ids[:max_length]
        return ids + [pad_idx] * (max_length - length)
def demo1():
    """Smoke test: tokenize a Chinese sentence with jieba and round-trip it
    through the vocabulary (tokens -> ids -> padded ids -> tokens)."""
    import jieba

    vocab = Vocabulary()
    for word in ('白天', '晚上'):
        vocab.add_token_to_namespace(word, 'tokens')

    sentence = '不是在白天, 就是在晚上'
    words = jieba.lcut(sentence)
    print(words)

    token_ids = vocab.convert_tokens_to_ids(words)
    print(token_ids)

    padded_ids = vocab.pad_or_truncate_ids_by_max_length(token_ids, 10)
    print(padded_ids)

    recovered = vocab.convert_ids_to_tokens(padded_ids)
    print(recovered)
    return
# Run the jieba/vocabulary demo only when executed as a script.
if __name__ == '__main__':
    demo1()