"""Token vocabulary with one-hot ("binary") encodings and a plain-text file format.

The vocabulary is persisted to ``<path>/words.vocab``, one entry per token in
the form ``<token>-><comma-separated one-hot vector>``.
"""
__author__ = 'Taneem Jan, taneemishere.github.io'

import sys

import numpy as np

# Special tokens registered by every Vocabulary, in this order (indices 0, 1, 2).
# START and END must differ: append() ignores duplicates, so equal values would
# leave the vocabulary without a separate end marker.
START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "
# Separates a token from its serialized vector in the vocab file.
SEPARATOR = '->'


class Vocabulary:
    """Bidirectional token <-> index mapping plus a one-hot vector per token.

    Attributes:
        vocabulary: token -> integer index.
        token_lookup: integer index -> token.
        binary_vocabulary: token -> one-hot numpy float vector of length
            ``size``; filled by create_binary_representation() or retrieve().
        size: number of distinct tokens known.
    """

    def __init__(self):
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register *token* under the next free index; known tokens are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """(Re)build ``binary_vocabulary``: one one-hot vector of length ``size`` per token."""
        for key, value in self.vocabulary.items():
            binary = np.zeros(self.size)
            binary[value] = 1
            self.binary_vocabulary[key] = binary

    def get_serialized_binary_representation(self):
        """Return the vocabulary as text, one ``token->v0,v1,...`` line per token.

        The one-hot vectors are (re)built first when they are missing or were
        built before more tokens were appended (their count differs from ``size``).
        """
        if len(self.binary_vocabulary) != self.size:
            self.create_binary_representation()

        lines = []
        for key, value in self.binary_vocabulary.items():
            # Disable numpy's line wrapping and "..." summarisation: either one
            # would split or truncate the vector and corrupt the file.
            array_as_string = np.array2string(value, separator=',',
                                              max_line_width=sys.maxsize,
                                              threshold=sys.maxsize)
            # Drop the surrounding "[" and "]".
            lines.append("{}{}{}\n".format(key, SEPARATOR, array_as_string[1:-1]))
        return "".join(lines)

    def save(self, path):
        """Write the serialized vocabulary to ``<path>/words.vocab``, overwriting it."""
        output_file_name = "{}/words.vocab".format(path)
        with open(output_file_name, 'w') as output_file:
            output_file.write(self.get_serialized_binary_representation())

    def retrieve(self, path):
        """Load entries from ``<path>/words.vocab`` into this vocabulary.

        Loaded entries are merged into the existing mappings (same tokens are
        overwritten) and ``size`` is then set to the number of known tokens.
        A token may itself contain newlines: lines without a separator are
        buffered and joined with the following line(s) until one is found.
        """
        with open("{}/words.vocab".format(path), 'r') as input_file:
            buffer = ""
            for line in input_file:
                try:
                    # The vector part never contains SEPARATOR, so its last
                    # occurrence is the real one even if the token contains it.
                    separator_position = len(buffer) + line.rindex(SEPARATOR)
                except ValueError:
                    # No separator yet: this line is part of a multi-line token.
                    buffer += line
                    continue
                buffer += line
                key = buffer[:separator_position]
                value = np.fromstring(buffer[separator_position + len(SEPARATOR):].strip(), sep=',')
                index = int(np.where(value == 1)[0][0])
                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key
                buffer = ""
        self.size = len(self.vocabulary)