class Vocab:
    """Character-level vocabulary with four reserved special-token ids.

    Ids 0-3 are reserved for padding (0), start-of-sequence "go" (1),
    end-of-sequence (2) and a mask token (3). The characters in ``chars``
    are assigned ids starting at 4, in the order given.
    """

    def __init__(self, chars):
        """Build the character <-> id lookup tables.

        Args:
            chars: An ordered collection of characters (a ``str`` or a list
                of single-character strings). ``chars[i]`` receives id
                ``i + 4``.
        """
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.mask_token = 3

        self.chars = chars

        self.c2i = {c: i + 4 for i, c in enumerate(chars)}
        self.i2c = {i + 4: c for i, c in enumerate(chars)}

        # Pad / go / eos decode to nothing; the mask token is rendered as a
        # visible placeholder.
        self.i2c[0] = ''
        self.i2c[1] = ''
        self.i2c[2] = ''
        self.i2c[3] = '*'

    def encode(self, chars):
        """Return ``[go] + [id of each char] + [eos]``.

        Raises:
            KeyError: if a character in ``chars`` is not in the vocabulary.
        """
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """Convert a sequence of ids back into a string.

        A leading ``go`` id is skipped and decoding stops at the first
        ``eos`` id. Pad / go / eos ids appearing elsewhere decode to ``''``;
        the mask id decodes to ``'*'``.

        Args:
            ids: A sequence of ints, or an array-like exposing ``tolist()``
                (e.g. a NumPy array or torch tensor).

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        # Normalise array-likes to a plain list of Python ints so that
        # ``.index`` works and elements are usable as dict keys.
        ids = ids.tolist() if hasattr(ids, 'tolist') else list(ids)
        # Only skip the first element when it actually is the go token;
        # a go id elsewhere must not cause a real character to be dropped.
        first = 1 if ids and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        """Return the number of ids: 4 specials plus one per entry in ``chars``.

        Uses ``len(self.chars)`` rather than ``len(self.c2i)`` so that the
        result is always ``max id + 1`` even when ``chars`` has duplicates
        (which collapse in ``c2i`` but still consume an id).
        """
        return len(self.chars) + 4

    def batch_decode(self, arr):
        """Decode each id sequence in ``arr``; return a list of strings."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        """Return the vocabulary characters as a single string."""
        # join() works whether ``chars`` is a str or a list of strings;
        # returning a non-str from __str__ raises TypeError.
        return ''.join(self.chars)