pop_k / utils.py
patchbanks's picture
Upload 12 files
7deef83
from torch.nn import functional as F
from torch.utils.data import Dataset
import numpy as np
import random
import torch
import re
# Vocabulary: the ten decimal digits map to themselves, then newline (10)
# and the twelve-zero start token (11).
stoi = {str(digit): digit for digit in range(10)}
stoi['\n'] = 10
stoi['000000000000'] = 11
# Inverse lookup: token id -> token string (same insertion order as stoi).
itos = {token_id: token for token, token_id in stoi.items()}
# The twelve-zero alternative comes first so that, at a given position, the
# start token is preferred over a single '0' digit.
tok_chars = re.compile(r'000000000000|\d{1}|\n')
def encode(text, stoi, tokenizer):
    """Split ``text`` with ``tokenizer`` and translate each token through ``stoi``.

    Tokens matched by ``tokenizer`` but missing from ``stoi`` are silently dropped.

    Args:
        text: String to tokenise.
        stoi: Mapping from token string to integer id.
        tokenizer: Compiled regex whose ``findall`` yields the tokens.

    Returns:
        List of integer token ids, in order of appearance.
    """
    token_ids = []
    for token in tokenizer.findall(text):
        if token not in stoi:
            continue
        token_ids.append(stoi[token])
    return token_ids
def decode(encoded, itos):
    """Turn a sequence of token ids back into text.

    Args:
        encoded: Iterable of integer token ids.
        itos: Mapping from token id to token string.

    Returns:
        The concatenation of ``itos[i]`` for every id, in order.

    Raises:
        KeyError: If an id is not present in ``itos``.
    """
    pieces = map(itos.__getitem__, encoded)
    return ''.join(pieces)
class Dataset:
    """Map-style dataset that samples token windows from one encoded corpus.

    The corpus is encoded once in ``__init__``. Every ``__getitem__`` call
    ignores its index and draws a random window that begins at an occurrence
    of the start token ``'000000000000'``.

    Note: this name shadows the ``torch.utils.data.Dataset`` imported at the
    top of the module; the class does not subclass it.
    """

    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        """Encode ``data`` and store the sampling parameters.

        Args:
            data: Raw text corpus, tokenised with the module-level ``tok_chars``.
            ctx_len: Length of the returned ``x`` / ``y`` sequences.
            epoch_length_fixed: Value reported by ``__len__``.
            time_aug: When True (default), each sample is time-shift augmented
                with probability 1/2; when False, augmentation is never applied.
        """
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.time_aug = time_aug
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')

    def __len__(self):
        """Return the fixed epoch length given at construction."""
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        """Return a random ``(x, y)`` pair of int64 tensors.

        ``idx`` is ignored. Returns ``None`` when the corpus contains no start
        token (or is empty).
        """
        start_idx = self._find_cue()
        if start_idx is None:
            return None
        dix = self.data[start_idx : start_idx + self.ctx_len + 2]
        # Short-circuit keeps the random-call sequence identical when enabled.
        data_aug = self.time_aug and random.choice([True, False])
        if data_aug:
            # Augmented windows skip the first two tokens of the slice, unlike
            # the plain branch below, which starts at the start token itself.
            t = self._time_shift(dix[2:2 + self.ctx_len])
            x = t
            # No successor token is sliced for t, so the final target repeats
            # the last input token. t[-1:] keeps this safe for an empty window.
            y = t[1:] + t[-1:]
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]
        x = torch.tensor(x, dtype=torch.int64)
        y = torch.tensor(y, dtype=torch.int64)
        return x, y

    def _find_cue(self):
        """Return the index of the first start token at/after a random offset, or None."""
        n = len(self.data)
        start_id = self.stoi[self.start_token]
        # Clamp at 0 so corpora shorter than 4 * ctx_len do not make
        # random.randint raise ValueError on an empty range.
        i = random.randint(0, max(0, n - self.ctx_len * 4))
        # Visit each position at most once (wrapping to the beginning), so a
        # corpus without a start token terminates instead of looping forever.
        for _ in range(n):
            if self.data[i] == start_id:
                return i
            i = (i + 1) % n
        return None

    def _time_shift(self, t):
        """Add a randomly chosen digit-offset pattern to ``t`` in place, with carry.

        Only positions whose index modulo 13 lies in 2..11 are shifted. Sums of
        10 or more are split into a carry added to the previous position and a
        remainder kept at the current one. Returns ``t``.
        """
        # 96 tick resolution
        time_shift = [
            [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
            [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
        ]
        pattern = random.choice(time_shift)
        # Tile the 13-entry pattern so it covers ctx_len positions.
        shifts = (pattern * ((self.ctx_len - 1) // len(pattern) + 1))[:self.ctx_len]
        # Walk right-to-left so a carry written to t[j - 1] is itself shifted
        # and normalised when the loop reaches j - 1.
        for j in reversed(range(len(t))):
            if j % 13 not in range(2, 12):
                continue
            aug_int = t[j] + shifts[j]
            # NOTE(review): the exemption below tests the absolute index j,
            # not j % 13, and lets a sum of 10 or 11 through unsplit. Those
            # values are the '\n' and start-token ids. Confirm this is intended.
            # Also, a carry out of a position with j % 13 == 2 lands on
            # position j % 13 == 1, which this loop skips.
            if aug_int >= 10 and (aug_int not in [10, 11] or j not in [9, 10]):
                if j > 0:
                    t[j - 1] += aug_int // 10
                t[j] = aug_int % 10
            else:
                t[j] = aug_int
        return t
class TOKENIZER():
    """Convert between text and token ids, and sample the next token from logits."""

    def __init__(self):
        # Reuse the module-level vocabulary and tokenizer regex.
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Return the ids of the tokens in ``text``; tokens not in ``stoi`` are dropped."""
        matches = self.tokenizer.findall(text)
        return [self.stoi[c] for c in matches if c in self.stoi]

    def decode(self, encoded):
        """Return the text for a sequence of token ids (KeyError on unknown ids)."""
        return ''.join([self.itos[i] for i in encoded])

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        """Sample one token index from the logits ``out``.

        Args:
            out: Logits over the vocabulary (tensor, array or sequence).
                Non-floating values are promoted to float32, and the caller's
                tensor is never modified.
            x: Unused; kept for call-site compatibility.
            ctx_len: Unused; kept for call-site compatibility.
            temperature: Probabilities are raised to ``1 / temperature``. This
                is equivalent to dividing the logits by ``temperature`` before
                the softmax. A value of 0 raises ZeroDivisionError.
            top_k: If > 0, keep only the ``top_k`` largest probabilities
                (clamped to the vocabulary size) and zero out the rest.

        Returns:
            ``torch.multinomial(probs, 1)[0]``, i.e. a 0-d index tensor for
            1-D logits.
        """
        # as_tensor avoids the copy-construct warning torch.tensor emits for
        # tensor inputs; detach keeps the in-place ops below out of autograd.
        logits = torch.as_tensor(out).detach()
        if not logits.is_floating_point():
            # softmax is not implemented for integer dtypes.
            logits = logits.float()
        probs = F.softmax(logits, dim=-1)
        if top_k > 0:
            top_k = min(top_k, probs.size(-1))
            sorted_probs, sorted_indices = torch.topk(probs, top_k)
            probs.fill_(0)
            probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)
        # multinomial accepts unnormalised non-negative weights.
        return torch.multinomial(probs, num_samples=1)[0]