Spaces:
Sleeping
Sleeping
import bisect
import random
import re

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
# Vocabulary: each decimal digit maps to its own value, newline maps to 10,
# and the 12-zero marker (used as the start token by ``Dataset``) maps to 11.
stoi = {str(digit): digit for digit in range(10)}
stoi['\n'] = 10
stoi['000000000000'] = 11
# Inverse lookup: token id -> token string (same insertion order as stoi).
itos = {token_id: token for token, token_id in stoi.items()}
# The 12-zero alternative is listed first so it is tried before a single digit
# at each position.
tok_chars = re.compile(r'000000000000|\d{1}|\n')
def encode(text, stoi, tokenizer):
    """Tokenize *text* and translate each token to its integer id.

    Args:
        text: String to tokenize.
        stoi: Mapping from token string to id.
        tokenizer: Object with a ``findall`` method (e.g. a compiled regex).

    Returns:
        List of ids, in order. Tokens produced by ``tokenizer`` that are not
        keys of ``stoi`` are skipped silently.
    """
    ids = []
    for token in tokenizer.findall(text):
        if token not in stoi:
            continue
        ids.append(stoi[token])
    return ids
def decode(encoded, itos):
    """Rebuild text from token ids.

    Args:
        encoded: Iterable of token ids.
        itos: Mapping from id to token string.

    Returns:
        The concatenated token strings. An id missing from ``itos`` raises
        ``KeyError``.
    """
    return ''.join(itos[token_id] for token_id in encoded)
class Dataset:
    """Randomly sampled, fixed-length training windows over an encoded corpus.

    Every window begins at a start token (``'000000000000'``). ``__getitem__``
    ignores its ``idx`` argument and draws a fresh random window on each call,
    so ``__len__`` only reports the configured epoch length.

    Args:
        data: Raw text, tokenized with the module-level ``encode``, ``stoi``
            and ``tok_chars``.
        ctx_len: Length of each returned ``x`` / ``y`` tensor.
        epoch_length_fixed: Value reported by ``__len__``.
        time_aug: When true, each sample is time-shift augmented with
            probability 1/2; when false, samples are never augmented.

    Note:
        The usable start positions are computed from ``data`` at construction
        time. If ``self.data`` is replaced later, refresh them with
        ``self._sample_starts = self._find_sample_starts()``.
    """

    # 96 tick resolution. Per-position offsets, tiled with period 13: the
    # 4-digit groups at positions 4-7 and 8-11 are shifted by 0768 or 1536.
    # (Presumably one 13-token period = 12 digits + newline; confirm against
    # the data format.)
    _TIME_SHIFTS = (
        [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
        [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
    )

    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.time_aug = time_aug
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')
        # Scan the corpus once here rather than on every __getitem__ call.
        self._sample_starts = self._find_sample_starts()

    def _find_sample_starts(self):
        """Return, ascending, every start-token index that has room for a full window."""
        start_id = self.stoi[self.start_token]
        # A window spans ctx_len + 2 tokens: the augmentation branch reads
        # dix[2:2 + ctx_len], and the plain branch needs ctx_len + 1 of them.
        last_start = len(self.data) - (self.ctx_len + 2)
        return [i for i, tok in enumerate(self.data)
                if tok == start_id and i <= last_start]

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        """Return a random ``(x, y)`` pair of int64 tensors of length ``ctx_len``.

        ``idx`` is ignored.

        Raises:
            ValueError: if no start token in the corpus is followed by at
                least ``ctx_len + 2`` tokens, so no full window exists.
        """
        starts = self._sample_starts
        if not starts:
            raise ValueError(
                f'no {self.start_token!r} start token is followed by '
                f'{self.ctx_len + 2} tokens; cannot sample a window'
            )
        # Random anchor; the bound is clamped so a short corpus does not make
        # randint fail with an empty range.
        anchor = random.randint(0, max(0, len(self.data) - self.ctx_len * 4))
        # First usable start at or after the anchor, wrapping to the first
        # usable start in the corpus (same choice as a forward scan with wrap).
        k = bisect.bisect_left(starts, anchor)
        start_idx = starts[k] if k < len(starts) else starts[0]
        dix = self.data[start_idx:start_idx + self.ctx_len + 2]

        if self.time_aug and random.choice([True, False]):
            x, y = self._time_shifted_pair(dix[2:2 + self.ctx_len])
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]
        return (torch.tensor(x, dtype=torch.int64),
                torch.tensor(y, dtype=torch.int64))

    def _time_shifted_pair(self, t):
        """Add a randomly chosen time shift to ``t`` (in place) and return ``(x, y)``.

        Positions 2-11 of each 13-token period receive the chosen offset.
        Positions are processed right to left, and any tens digit is carried
        into the position on the left. ``y`` is ``t`` shifted left by one, with
        its last token repeated.
        """
        shift = random.choice(self._TIME_SHIFTS)
        # Tile the 13-entry pattern so it covers the whole window.
        offsets = (shift * ((self.ctx_len - 1) // len(shift) + 1))[:self.ctx_len]
        for j in reversed(range(len(t))):
            if j % 13 not in range(2, 12):
                continue
            aug_int = t[j] + offsets[j]
            # NOTE(review): the exemption tests the absolute index j, not
            # j % 13, so it only applies inside the first period. When the
            # exemption applies, the stored 10/11 equals the newline/start-token
            # id. Behaviour kept as-is; confirm intent.
            if aug_int >= 10 and (aug_int not in (10, 11) or j not in (9, 10)):
                # j >= 2 on this path, so j - 1 is always a valid index.
                t[j - 1] += aug_int // 10
                t[j] = aug_int % 10
            else:
                t[j] = aug_int
        # NOTE(review): unlike the plain branch, this window starts at dix[2],
        # and the last target repeats the last input token instead of the true
        # next token.
        return t, t[1:] + t[-1:]
class TOKENIZER:
    """Encode/decode text over the module vocabulary and sample from model logits."""

    def __init__(self):
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Return the ids of the tokens in *text*; tokens not in ``stoi`` are dropped."""
        matches = self.tokenizer.findall(text)
        return [self.stoi[c] for c in matches if c in self.stoi]

    def decode(self, encoded):
        """Concatenate the token strings for the ids in *encoded*."""
        return ''.join([self.itos[i] for i in encoded])

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        """Sample a token id from the logits *out*.

        Args:
            out: Logits with the vocabulary on the last dimension (a list,
                array or tensor accepted by ``torch.as_tensor``).
            x: Unused; kept for call compatibility.
            ctx_len: Unused; kept for call compatibility.
            temperature: Sampling temperature. ``0`` selects the arg-max
                (greedy). Negative values are rejected.
            top_k: If > 0, sample only among the ``top_k`` most likely tokens.

        Returns:
            ``torch.multinomial(probs, 1)[0]``, which is a 0-d tensor for 1-D
            ``out``. The greedy path returns the same shape.

        Raises:
            ValueError: if ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f'temperature must be >= 0, got {temperature}')
        # as_tensor avoids the copy-construct warning torch.tensor() emits for
        # tensor input. detach keeps autograd out of the sampling step.
        logits = torch.as_tensor(out).detach()
        if temperature == 0:
            # keepdim + [0] mirrors the shape of multinomial(..., 1)[0].
            return torch.argmax(logits, dim=-1, keepdim=True)[0]
        if temperature != 1.0:
            # Dividing the logits by T gives the same distribution as raising
            # the probabilities to 1/T and renormalizing. Unlike that form, it
            # cannot underflow to an all-zero distribution at small T.
            logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        if top_k > 0:
            top_k = min(top_k, probs.size(-1))
            top_probs, top_idx = torch.topk(probs, top_k)
            # Zero every probability outside the top-k; multinomial renormalizes.
            probs = torch.zeros_like(probs).scatter_(-1, top_idx, top_probs)
        return torch.multinomial(probs, num_samples=1)[0]