File size: 5,146 Bytes
e8f4897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
"""
Alphabet maps objects to integer ids. It provides two way mapping from the index to the objects.
"""
import json
import os
from .logger import get_logger
class Alphabet(object):
    """Bidirectional mapping between hashable instances and integer ids.

    Instances receive consecutive ids in insertion order, starting at
    ``offset``. When a default value is enabled, id 0 is reserved for the
    wildcard token ``"<_UNK>"`` and real instances start at id 1.
    """

    # Reserved wildcard token; it is never stored as a regular instance.
    _WILDCARD = '<_UNK>'

    def __init__(self, name, defualt_value=False, keep_growing=True, singleton=False):
        """Create an empty alphabet.

        :param name: Alphabet name; used as the default file stem by
            ``save``/``load`` and in error messages.
        :param defualt_value: If truthy, reserve index 0 for the wildcard
            and return it for unknown instances once the alphabet is closed.
            (The misspelled parameter name is kept for caller compatibility.)
        :param keep_growing: If true, ``get_index`` adds unseen instances.
        :param singleton: If true, maintain a set of singleton ids.
        """
        self.__name = name

        self.instance2index = {}
        self.instances = []
        self.default_value = defualt_value
        self.offset = 1 if self.default_value else 0
        self.keep_growing = keep_growing
        self.singletons = set() if singleton else None

        # Index 0 is occupied by default, all else following.
        self.default_index = 0 if self.default_value else None

        self.next_index = self.offset

        self.logger = get_logger("Alphabet")

    def _require_singletons(self):
        """Raise RuntimeError if this alphabet does not track singletons."""
        if self.singletons is None:
            raise RuntimeError("Alphabet %s does not have singleton." % self.__name)

    def add(self, instance):
        """Assign the next free index to ``instance`` if it is new.

        The wildcard token ``"<_UNK>"`` is never added, and instances that
        are already present are left untouched.
        """
        if instance not in self.instance2index and instance != self._WILDCARD:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def add_singleton(self, id):
        """Mark ``id`` as a singleton.

        :raises RuntimeError: if the alphabet does not track singletons.
        """
        self._require_singletons()
        self.singletons.add(id)

    def add_singletons(self, ids):
        """Mark every id in ``ids`` as a singleton.

        :raises RuntimeError: if the alphabet does not track singletons.
        """
        self._require_singletons()
        self.singletons.update(ids)

    def is_singleton(self, id):
        """Return whether ``id`` has been marked as a singleton.

        :raises RuntimeError: if the alphabet does not track singletons.
        """
        self._require_singletons()
        return id in self.singletons

    def get_index(self, instance):
        """Return the index of ``instance``.

        While the alphabet is open (``keep_growing``), an unseen instance is
        added and its new index returned. Otherwise the default index is
        returned when a default value is enabled.

        :raises KeyError: if the instance has no index and there is no
            default index to fall back on.
        """
        try:
            return self.instance2index[instance]
        except KeyError:
            pass

        if self.keep_growing:
            self.add(instance)
            # add() refuses the reserved wildcard token, so only report an
            # index that was really assigned; otherwise the next instance
            # added would silently receive the same id.
            if instance in self.instance2index:
                return self.instance2index[instance]

        if self.default_value:
            return self.default_index
        raise KeyError("instance not found: %s" % instance)

    def get_instance(self, index):
        """Return the instance stored at ``index``.

        :raises IndexError: if ``index`` is not a valid id of this alphabet.
        """
        if self.default_value and index == self.default_index:
            # First index is occupied by the wildcard element.
            return self._WILDCARD

        position = index - self.offset
        # Reject negative positions explicitly: list indexing would wrap
        # around and return an unrelated instance.
        if position < 0 or position >= len(self.instances):
            raise IndexError("unknown index: %d" % index)
        return self.instances[position]

    def size(self):
        """Return the number of ids in use, including the default slot."""
        return len(self.instances) + self.offset

    def singleton_size(self):
        """Return the number of ids marked as singletons.

        :raises RuntimeError: if the alphabet does not track singletons.
        """
        self._require_singletons()
        return len(self.singletons)

    def items(self):
        """Return the (instance, index) pairs of the mapping."""
        return self.instance2index.items()

    def keys(self):
        """Return the stored instances as the mapping's keys."""
        return self.instance2index.keys()

    def values(self):
        """Return the indices assigned to the stored instances."""
        return self.instance2index.values()

    def token_in_alphabet(self, token):
        """Return whether ``token`` has been assigned an index."""
        # Dict membership is a direct hash lookup; no need to copy the keys.
        return token in self.instance2index

    def enumerate_items(self, start):
        """Return (index, instance) pairs for all ids from ``start`` onward.

        :raises IndexError: if ``start`` is below ``offset`` or not smaller
            than the alphabet size.
        """
        if start < self.offset or start >= self.size():
            raise IndexError("Enumerate is allowed between [%d : size of the alphabet)" % self.offset)
        return zip(range(start, len(self.instances) + self.offset), self.instances[start - self.offset:])

    def close(self):
        """Stop adding unseen instances in ``get_index``."""
        self.keep_growing = False

    def open(self):
        """Allow ``get_index`` to add unseen instances again."""
        self.keep_growing = True

    def get_content(self):
        """Return the JSON-serialisable state of the alphabet.

        The ``"singletions"`` key (sic) is only present when singletons are
        tracked; its spelling is part of the on-disk format.
        """
        content = {"instance2index": self.instance2index, "instances": self.instances}
        if self.singletons is not None:
            content["singletions"] = list(self.singletons)
        return content

    def __from_json(self, data):
        """Restore instances, mapping and singletons from ``get_content`` data."""
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]
        if "singletions" in data:
            self.singletons = set(data["singletions"])
        else:
            self.singletons = None

    def save(self, output_directory, name=None):
        """
        Save the alphabet content as ``<name>.json`` in the given directory.

        The directory is created if it does not exist. Failures are logged
        as warnings rather than raised.

        :param output_directory: Directory to write the JSON file to.
        :param name: The alphabet saving name, optional; defaults to the
            alphabet's own name.
        :return: None
        """
        saving_name = name if name else self.__name
        try:
            if not os.path.exists(output_directory):
                os.makedirs(output_directory)
            target = os.path.join(output_directory, saving_name + ".json")
            # Context manager guarantees the file is flushed and closed.
            with open(target, 'w') as handle:
                json.dump(self.get_content(), handle, indent=4)
        except Exception as e:
            self.logger.warning("Alphabet is not saved: %s", repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet content from ``<name>.json`` in the given directory.

        After loading, the next free index is recomputed from the number of
        instances and the alphabet is closed (``keep_growing`` is False).

        :param input_directory: Directory containing the saved JSON file.
        :param name: The alphabet loading name, optional; defaults to the
            alphabet's own name.
        :return: None
        """
        loading_name = name if name else self.__name
        filename = os.path.join(input_directory, loading_name + ".json")
        with open(filename) as handle:
            data = json.load(handle)
        self.__from_json(data)
        self.next_index = len(self.instances) + self.offset
        self.keep_growing = False
|