Spaces:
Runtime error
Runtime error
class Vocab:
    """Character-level vocabulary mapping characters to integer ids and back.

    Ids 0-3 are reserved for special tokens (padding, start-of-sequence,
    end-of-sequence and mask); the characters passed to the constructor are
    numbered from 4 upward in the order given.
    """

    def __init__(self, chars):
        """Build the lookup tables.

        Args:
            chars: Iterable of single characters (typically a ``str``) that
                make up the vocabulary.
        """
        # Reserved ids for the special tokens.
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.mask_token = 3

        self.chars = chars

        # Regular characters start at 4, right after the reserved ids.
        self.c2i = {c: i + 4 for i, c in enumerate(chars)}
        self.i2c = {i + 4: c for i, c in enumerate(chars)}

        # Printable forms of the special tokens, used by ``decode``.
        self.i2c[0] = '<pad>'
        self.i2c[1] = '<sos>'
        self.i2c[2] = '<eos>'
        self.i2c[3] = '*'

    def encode(self, chars):
        """Return the ids for ``chars`` wrapped in start and end tokens.

        Raises:
            KeyError: If a character is not part of the vocabulary.
        """
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """Turn a sequence of ids back into a string.

        A leading start token is dropped and the output stops before the
        first end token (if any). ``ids`` must support ``len``, indexing,
        slicing and ``.index`` (e.g. a ``list``).
        """
        # Only a start token in the FIRST position is a marker to strip; a
        # membership test over the whole sequence would drop a real first
        # token whenever the start id shows up later on.
        first = 1 if len(ids) > 0 and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        """Return the number of distinct characters plus the 4 special ids."""
        return len(self.c2i) + 4

    def batch_decode(self, arr):
        """Decode every id sequence in ``arr`` and return a list of strings."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        """Return the vocabulary characters as one string."""
        # ``__str__`` must return a ``str``; joining keeps a ``str`` input
        # unchanged and also works when ``chars`` was given as a list/tuple.
        return ''.join(self.chars)