FRIENDS-GPT / data_utils.py
bala1802's picture
Upload 7 files
dabde41
raw
history blame
1 kB
import torch
# Load the whole training corpus into memory as one string. This runs at
# import time, so 'data/input.txt' must exist relative to the working directory.
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character that occurs in the
# corpus. Sorting makes the character -> id assignment deterministic for a
# given corpus.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer ids (position in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encode a string as a list of integer ids, one per character.

    Raises:
        KeyError: if ``s`` contains a character that is not in ``stoi``
            (i.e. one that never occurs in the corpus).
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer ids back into a string.

    Raises:
        KeyError: if an id is not a key of ``itos``.
    """
    return ''.join(itos[i] for i in l)


# Train and validation splits over the encoded corpus.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
def get_train_data() -> torch.Tensor:
    """Return the training split of the encoded corpus.

    This is the module-level ``train_data`` tensor, i.e. ``data[:n]`` where
    ``n`` is 90% of the encoded corpus length. The tensor object itself is
    returned, not a copy.
    """
    return train_data
def get_val_data() -> torch.Tensor:
    """Return the validation split of the encoded corpus.

    This is the module-level ``val_data`` tensor, i.e. ``data[n:]`` where
    ``n`` is 90% of the encoded corpus length. The tensor object itself is
    returned, not a copy.
    """
    return val_data
def get_data() -> torch.Tensor:
    """Return the full encoded corpus.

    This is the module-level ``data`` tensor: the whole input text encoded
    character by character as ``torch.long`` ids. The tensor object itself is
    returned, not a copy.
    """
    return data
def get_encoder():
    """Return the module-level ``encode`` callable.

    ``encode`` takes a string and returns a list of integer ids, one per
    character, looked up in ``stoi``.
    """
    return encode
def get_decoder():
    """Return the module-level ``decode`` callable.

    ``decode`` takes a list of integer ids and returns the string formed by
    joining the corresponding characters from ``itos``.
    """
    return decode
def get_vocab_size() -> int:
    """Return the number of distinct characters in the corpus.

    This is the module-level ``vocab_size``, computed as ``len(chars)``.
    """
    return vocab_size