# potterGPT-v0/model/tokenizer.py
import torch


class CharacterLevelTokenizer:
    """Character-level tokenizer: maps each unique character in the corpus to an integer id and back."""

    def __init__(self, data):
        self.data = data
        # Vocabulary is the sorted set of unique characters in the corpus.
        self.vocab = sorted(list(set(self.data)))
        self.VOCAB_SIZE = len(self.vocab)
        # Lookup tables: index -> character and character -> index.
        self.i_s = {i: s for i, s in enumerate(self.vocab)}
        self.s_i = {s: i for i, s in self.i_s.items()}

    def encode(self, s):
        # Convert a string to a 1-D LongTensor of character ids.
        return torch.tensor([self.s_i[c] for c in s], dtype=torch.long)

    def decode(self, s):
        # Convert a tensor of character ids back into a string.
        return ''.join([self.i_s[i.item()] for i in s])
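

# A minimal usage sketch (not part of the original file): the corpus string
# below is an illustrative assumption, not project data. It shows the intended
# encode/decode round trip on a tensor of character ids.
if __name__ == "__main__":
    corpus = "harry potter and the philosopher's stone"
    tokenizer = CharacterLevelTokenizer(corpus)
    ids = tokenizer.encode("harry")   # tensor of character ids
    text = tokenizer.decode(ids)      # back to the original string
    assert text == "harry"
    print(tokenizer.VOCAB_SIZE, ids.tolist(), text)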