hantech's picture
Upload 38 files
bd22b5e
raw
history blame contribute delete
No virus
944 Bytes
class Vocab:
    """Character vocabulary mapping characters to integer ids and back.

    Ids 0-3 are reserved for special tokens: padding (0), start-of-sequence
    (1), end-of-sequence (2) and mask (3). The characters in ``chars`` are
    assigned ids starting at 4, in iteration order.

    Args:
        chars: The alphabet, typically a ``str`` (a list of single-character
            strings also works).
    """

    def __init__(self, chars):
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.mask_token = 3
        self.chars = chars
        # Characters are shifted by 4 so they never collide with the
        # reserved special-token ids above.
        self.c2i = {c: i + 4 for i, c in enumerate(chars)}
        self.i2c = {i + 4: c for i, c in enumerate(chars)}
        self.i2c[0] = '<pad>'
        self.i2c[1] = '<sos>'
        self.i2c[2] = '<eos>'
        self.i2c[3] = '*'

    def encode(self, chars):
        """Return the ids of ``chars`` wrapped in start/end-of-sequence ids.

        Raises:
            KeyError: If a character is not in the vocabulary.
        """
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """Convert a sequence of ids back into a string.

        A leading start-of-sequence id is dropped. Decoding stops at the first
        end-of-sequence id, so anything after it (e.g. padding) is ignored.
        Objects exposing ``tolist()`` (e.g. NumPy arrays) are converted to a
        plain list first; any other iterable of ints is also accepted.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        if hasattr(ids, 'tolist'):
            ids = ids.tolist()
        ids = list(ids)
        # Only skip the start id when it actually opens the sequence; a stray
        # start id elsewhere must not swallow the first real character.
        first = 1 if ids and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        """Return the vocabulary size, special tokens included.

        Counted from ``i2c`` rather than ``c2i`` so that duplicate characters
        in ``chars`` are accounted for. Each duplicate still occupies an id,
        so the result is always one more than the largest id.
        """
        return len(self.i2c)

    def batch_decode(self, arr):
        """Decode each id sequence in ``arr``; return the list of strings."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        """Return the alphabet as a single string."""
        # join also handles a list of characters. Returning a non-str from
        # __str__ would raise TypeError.
        return ''.join(self.chars)