import random
import re

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset

# Vocabulary: the ten digits, newline, and a 12-zero run that the regex
# below captures as a single start-of-sequence token.
stoi = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '\n': 10, '000000000000': 11}
itos = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '\n', 11: '000000000000'}
tok_chars = re.compile(r'000000000000|\d|\n')

def encode(text, stoi, tokenizer):
    matches = tokenizer.findall(text)
    return [stoi[c] for c in matches if c in stoi]

def decode(encoded, itos):
    return ''.join([itos[i] for i in encoded])
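
# Round-trip sketch (assumed usage; the 12-zero run is captured as a single
# token by the regex alternation):
#     encode('000000000000\n42\n', stoi, tok_chars)  ->  [11, 10, 4, 2, 10]
#     decode([11, 10, 4, 2, 10], itos)               ->  '000000000000\n42\n'
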
class Dataset(Dataset):  # subclasses torch.utils.data.Dataset imported above
    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        self.time_aug = time_aug
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')

    def __len__(self):
        # Samples are drawn at random positions, so the epoch length is a
        # fixed setting rather than a function of the data size.
        return self.epoch_length_fixed
    def __getitem__(self, idx):
        # idx is ignored: pick a random position (leaving a 4x ctx_len
        # margin), then scan forward with wrap-around until the next
        # start-of-sequence token.
        idx_randm = random.randint(0, len(self.data) - self.ctx_len * 4)
        i = idx_randm
        while self.data[i] != self.stoi[self.start_token]:
            i = (i + 1) % len(self.data)
        start_idx = i
        dix = self.data[start_idx : start_idx + self.ctx_len + 2]
        # 96 tick resolution: candidate time-shift patterns, tiled to cover
        # the full context when augmentation is applied.
        time_shift = [
            [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
            [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
        ]
        data_aug = self.time_aug and random.choice([True, False])
        t = dix[2:2 + self.ctx_len]  # drop the first two tokens of the window
        if data_aug:
            ts_rndm = random.choice(time_shift)
            ts = ts_rndm * ((self.ctx_len - 1) // len(ts_rndm) + 1)
            tsx = ts[:self.ctx_len]  # plain list, so the sums below stay ints
            # Add the shift digit-by-digit from right to left, carrying
            # overflow into the digit to the left (e.g. 7 + 6 = 13 becomes
            # right digit 3 with 1 added to the neighbour).
            for j in reversed(range(len(t))):
                if j % 13 not in range(2, 12):
                    continue  # only positions 2-11 of each 13-token group are shifted
                aug_int = t[j] + tsx[j]
                if aug_int >= 10 and (aug_int not in [10, 11] or j not in [9, 10]):
                    left_int = aug_int // 10
                    right_int = aug_int % 10
                    if j > 0:
                        t[j - 1] += left_int
                    t[j] = right_int
                else:
                    # Sums of 10/11 at positions 9-10 are kept whole.
                    t[j] = aug_int
            x = t
            y = t[1:] + [t[-1]]  # next-token targets; final target repeats the last token
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]
        x = torch.tensor(x, dtype=torch.int64)
        y = torch.tensor(y, dtype=torch.int64)
        return x, y
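
# A minimal usage sketch (assumed, not part of the original file;
# hyperparameters are illustrative):
#     from torch.utils.data import DataLoader
#     ds = Dataset(text, ctx_len=256, epoch_length_fixed=10000)
#     loader = DataLoader(ds, batch_size=12)
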
class TOKENIZER:
    def __init__(self):
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        matches = self.tokenizer.findall(text)
        return [self.stoi[c] for c in matches if c in self.stoi]

    def decode(self, encoded):
        return ''.join([self.itos[i] for i in encoded])
    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        # Convert logits to probabilities, zero out everything outside the
        # top-k, optionally reshape with temperature, then sample one id.
        probs = F.softmax(torch.tensor(out), dim=-1)
        if top_k > 0:
            top_k = min(top_k, probs.size(-1))
            sorted_probs, sorted_indices = torch.topk(probs, top_k)
            probs.fill_(0)
            probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)
        # torch.multinomial renormalizes, so probs need not sum to 1 here.
        return torch.multinomial(probs, num_samples=1)[0]
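
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module):
    # synthetic start-token-delimited rows, one (x, y) draw, one sampled id.
    rows = ''.join('000000000000' + '0123456789' * 5 + '\n' for _ in range(64))
    ds = Dataset(rows, ctx_len=64, epoch_length_fixed=100)
    x, y = ds[0]
    print(x.shape, y.shape)  # typically torch.Size([64]) each

    tok = TOKENIZER()
    logits = [0.0] * tok.vocab_size  # hypothetical uniform logits
    next_id = tok.sample_logits(logits, x, ctx_len=64, top_k=5)
    print('sampled token id:', int(next_id))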