|
""" |
|
Alphabet maps objects to integer ids. It provides two way mapping from the index to the objects. |
|
""" |
|
import json |
|
import os |
|
from .logger import get_logger |
|
|
|
class Alphabet(object): |
|
def __init__(self, name, defualt_value=False, keep_growing=True, singleton=False): |
|
self.__name = name |
|
|
|
self.instance2index = {} |
|
self.instances = [] |
|
self.default_value = defualt_value |
|
self.offset = 1 if self.default_value else 0 |
|
self.keep_growing = keep_growing |
|
self.singletons = set() if singleton else None |
|
|
|
|
|
self.default_index = 0 if self.default_value else None |
|
|
|
self.next_index = self.offset |
|
|
|
self.logger = get_logger("Alphabet") |
|
|
|
def add(self, instance): |
|
if instance not in self.instance2index and instance != '<_UNK>': |
|
self.instances.append(instance) |
|
self.instance2index[instance] = self.next_index |
|
self.next_index += 1 |
|
|
|
def add_singleton(self, id): |
|
if self.singletons is None: |
|
raise RuntimeError("Alphabet %s does not have singleton." % self.__name) |
|
else: |
|
self.singletons.add(id) |
|
|
|
def add_singletons(self, ids): |
|
if self.singletons is None: |
|
raise RuntimeError("Alphabet %s does not have singleton." % self.__name) |
|
else: |
|
self.singletons.update(ids) |
|
|
|
def is_singleton(self, id): |
|
if self.singletons is None: |
|
raise RuntimeError("Alphabet %s does not have singleton." % self.__name) |
|
else: |
|
return id in self.singletons |
|
|
|
def get_index(self, instance): |
|
try: |
|
return self.instance2index[instance] |
|
except KeyError: |
|
if self.keep_growing: |
|
index = self.next_index |
|
self.add(instance) |
|
return index |
|
else: |
|
if self.default_value: |
|
return self.default_index |
|
else: |
|
raise KeyError("instance not found: %s" % instance) |
|
|
|
def get_instance(self, index): |
|
if self.default_value and index == self.default_index: |
|
|
|
return "<_UNK>" |
|
else: |
|
try: |
|
return self.instances[index - self.offset] |
|
except IndexError: |
|
raise IndexError("unknown index: %d" % index) |
|
|
|
def size(self): |
|
return len(self.instances) + self.offset |
|
|
|
def singleton_size(self): |
|
return len(self.singletons) |
|
|
|
def items(self): |
|
return self.instance2index.items() |
|
|
|
def keys(self): |
|
return self.instance2index.keys() |
|
|
|
def values(self): |
|
return self.instance2index.values() |
|
|
|
def token_in_alphabet(self, token): |
|
return token in set(self.instance2index.keys()) |
|
|
|
def enumerate_items(self, start): |
|
if start < self.offset or start >= self.size(): |
|
raise IndexError("Enumerate is allowed between [%d : size of the alphabet)" % self.offset) |
|
return zip(range(start, len(self.instances) + self.offset), self.instances[start - self.offset:]) |
|
|
|
def close(self): |
|
self.keep_growing = False |
|
|
|
def open(self): |
|
self.keep_growing = True |
|
|
|
def get_content(self): |
|
if self.singletons is None: |
|
return {"instance2index": self.instance2index, "instances": self.instances} |
|
else: |
|
return {"instance2index": self.instance2index, "instances": self.instances, |
|
"singletions": list(self.singletons)} |
|
|
|
def __from_json(self, data): |
|
self.instances = data["instances"] |
|
self.instance2index = data["instance2index"] |
|
if "singletions" in data: |
|
self.singletons = set(data["singletions"]) |
|
else: |
|
self.singletons = None |
|
|
|
def save(self, output_directory, name=None): |
|
""" |
|
Save both alhpabet records to the given directory. |
|
:param output_directory: Directory to save model and weights. |
|
:param name: The alphabet saving name, optional. |
|
:return: |
|
""" |
|
saving_name = name if name else self.__name |
|
try: |
|
if not os.path.exists(output_directory): |
|
os.makedirs(output_directory) |
|
|
|
json.dump(self.get_content(), |
|
open(os.path.join(output_directory, saving_name + ".json"), 'w'), indent=4) |
|
except Exception as e: |
|
self.logger.warn("Alphabet is not saved: %s" % repr(e)) |
|
|
|
def load(self, input_directory, name=None): |
|
""" |
|
Load model architecture and weights from the give directory. This allow we use old saved_models even the structure |
|
changes. |
|
:param input_directory: Directory to save model and weights |
|
:return: |
|
""" |
|
loading_name = name if name else self.__name |
|
filename = os.path.join(input_directory, loading_name + ".json") |
|
f = json.load(open(filename)) |
|
self.__from_json(f) |
|
self.next_index = len(self.instances) + self.offset |
|
self.keep_growing = False |
|
|