File size: 5,146 Bytes
e8f4897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Alphabet maps objects to integer ids. It provides a two-way mapping between objects and their indices.
"""
import json
import os
from .logger import get_logger

class Alphabet(object):
    """
    Two-way mapping between hashable instances and consecutive integer indices.

    When a default value is enabled, index 0 is reserved for the wildcard
    token "<_UNK>" and real instances are numbered from 1; otherwise real
    instances are numbered from 0.
    """

    # Reserved wildcard token; it is never stored in ``instances``.
    _UNKNOWN = "<_UNK>"

    def __init__(self, name, defualt_value=False, keep_growing=True, singleton=False):
        """
        :param name: Name of the alphabet, used as the default file name in save/load.
        :param defualt_value: If true, reserve index 0 for unknown instances.
            (The misspelled keyword is kept for backward compatibility.)
        :param keep_growing: If true, get_index adds unseen instances on lookup.
        :param singleton: If true, keep a set of singleton ids alongside the mapping.
        """
        self.__name = name

        self.instance2index = {}
        self.instances = []
        self.default_value = defualt_value
        self.offset = 1 if self.default_value else 0
        self.keep_growing = keep_growing
        self.singletons = set() if singleton else None

        # Index 0 is occupied by default, all else following.
        self.default_index = 0 if self.default_value else None

        self.next_index = self.offset

        self.logger = get_logger("Alphabet")

    def add(self, instance):
        """
        Register an instance under the next free index.

        Instances already present and the reserved "<_UNK>" token are ignored.
        """
        if instance not in self.instance2index and instance != self._UNKNOWN:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def __require_singletons(self):
        """Raise RuntimeError when this alphabet was created without singleton support."""
        if self.singletons is None:
            raise RuntimeError("Alphabet %s does not have singleton." % self.__name)

    def add_singleton(self, id):
        """
        Mark one id as a singleton.

        :raises RuntimeError: if singleton tracking is disabled.
        """
        self.__require_singletons()
        self.singletons.add(id)

    def add_singletons(self, ids):
        """
        Mark every id in the iterable as a singleton.

        :raises RuntimeError: if singleton tracking is disabled.
        """
        self.__require_singletons()
        self.singletons.update(ids)

    def is_singleton(self, id):
        """
        :return: True if the id was marked as a singleton.
        :raises RuntimeError: if singleton tracking is disabled.
        """
        self.__require_singletons()
        return id in self.singletons

    def get_index(self, instance):
        """
        Return the index of an instance.

        Unseen instances are added while the alphabet is growing. Otherwise
        the default index is returned when a default value is enabled.

        :raises KeyError: if the instance is unknown and can neither be added
            nor mapped to the default index.
        """
        try:
            return self.instance2index[instance]
        except KeyError:
            pass

        if self.keep_growing:
            self.add(instance)
            # add() refuses the reserved "<_UNK>" token, so only report an
            # index that was really assigned; otherwise fall through.
            if instance in self.instance2index:
                return self.instance2index[instance]

        if self.default_value:
            return self.default_index
        raise KeyError("instance not found: %s" % instance)

    def get_instance(self, index):
        """
        Return the instance stored at an index.

        :raises IndexError: if the index is outside the alphabet.
        """
        if self.default_value and index == self.default_index:
            # First index is occupied by the wildcard element.
            return self._UNKNOWN

        position = index - self.offset
        # Reject negative positions explicitly: Python would otherwise wrap
        # around and silently return an instance from the end of the list.
        if position < 0 or position >= len(self.instances):
            raise IndexError("unknown index: %d" % index)
        return self.instances[position]

    def size(self):
        """:return: Number of indices in use, including the reserved default slot."""
        return len(self.instances) + self.offset

    def singleton_size(self):
        """
        :return: Number of ids marked as singletons.
        :raises RuntimeError: if singleton tracking is disabled.
        """
        self.__require_singletons()
        return len(self.singletons)

    def items(self):
        """:return: View of (instance, index) pairs."""
        return self.instance2index.items()

    def keys(self):
        """:return: View of the registered instances."""
        return self.instance2index.keys()

    def values(self):
        """:return: View of the assigned indices."""
        return self.instance2index.values()

    def token_in_alphabet(self, token):
        """:return: True if the token has been registered."""
        # Direct dict membership; no need to copy the keys into a set first.
        return token in self.instance2index

    def enumerate_items(self, start):
        """
        Iterate over (index, instance) pairs beginning at the given index.

        :param start: First index to yield; must lie in [offset, size).
        :raises IndexError: if start is outside that range.
        """
        if start < self.offset or start >= self.size():
            raise IndexError("Enumerate is allowed between [%d : size of the alphabet)" % self.offset)
        return zip(range(start, self.size()), self.instances[start - self.offset:])

    def close(self):
        """Stop get_index from adding unseen instances."""
        self.keep_growing = False

    def open(self):
        """Allow get_index to add unseen instances again."""
        self.keep_growing = True

    def get_content(self):
        """
        :return: JSON-serializable dict with the mapping, the instance list
            and, when singleton tracking is enabled, the singleton ids.
        """
        content = {"instance2index": self.instance2index, "instances": self.instances}
        if self.singletons is not None:
            # The misspelled key is part of the on-disk format; keep it so
            # previously saved alphabets still load.
            content["singletions"] = list(self.singletons)
        return content

    def __from_json(self, data):
        """Restore the mapping, instance list and singleton set from a loaded dict."""
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]
        if "singletions" in data:
            self.singletons = set(data["singletions"])
        else:
            self.singletons = None

    def save(self, output_directory, name=None):
        """
        Save the alphabet content as "<name>.json" in the given directory.

        Saving is best effort: any failure is logged as a warning and not raised.

        :param output_directory: Directory to write to; created if missing.
        :param name: The alphabet saving name, optional.
        :return:
        """
        saving_name = name if name else self.__name
        try:
            if not os.path.exists(output_directory):
                os.makedirs(output_directory)

            content = self.get_content()
            path = os.path.join(output_directory, saving_name + ".json")
            # The context manager closes the file even if dumping fails.
            with open(path, 'w') as out:
                json.dump(content, out, indent=4)
        except Exception as e:
            self.logger.warning("Alphabet is not saved: %s" % repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet content from "<name>.json" in the given directory.

        After loading, the next free index is recomputed and the alphabet is
        closed, so lookups no longer add new instances.

        :param input_directory: Directory containing the saved alphabet.
        :param name: The alphabet loading name, optional.
        :return:
        """
        loading_name = name if name else self.__name
        filename = os.path.join(input_directory, loading_name + ".json")
        with open(filename) as f:
            data = json.load(f)
        self.__from_json(data)
        self.next_index = len(self.instances) + self.offset
        self.keep_growing = False