# Scraped source metadata: commit 82f9e44, file size 2,295 bytes.
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import json
class BilingualDataset(Dataset):
    """Sliding-window next-token dataset over a flattened instruction/response stream.

    Every record of ``ds`` is tokenized as::

        <user> (instruction + " " + input) <ai> output </s>

    and all records are concatenated into one flat list of token ids
    (``self.data_tokens``). Item ``i`` takes the ``seq_len - 1`` ids starting
    at ``i * stride`` and frames them as ``[<s>] + window + [<pad>]``.
    ``"input"`` drops the last id and ``"label"`` drops the first, so both are
    1-D int64 tensors of length ``seq_len`` and ``label[k]`` is the token that
    follows ``input[k]``. The final label position is always ``<pad>``, so
    presumably the loss ignores the pad id (TODO confirm in the training loop).

    Args:
        ds: iterable of mappings with ``'instruction'``, ``'input'`` and
            ``'output'`` string fields.
        tokenizer: object exposing ``token_to_id(str)`` and ``encode(str)``
            whose result has an ``.ids`` sequence of ints (presumably a
            HuggingFace ``tokenizers.Tokenizer`` -- verify against caller).
        seq_len: length of the returned input/label tensors.

    Raises:
        ValueError: if the tokenizer has no id for one of the special tokens
            ``<s>``, ``</s>``, ``<pad>``, ``<user>``, ``<ai>``.
    """

    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        # Consecutive windows overlap by half. Clamp to 1 so seq_len == 1 cannot
        # produce a zero stride (which would divide by zero in __len__).
        self.stride = max(1, seq_len // 2)

        # Plain int ids are used internally so token lists stay homogeneous.
        self._sos_id = self._special_id(tokenizer, '<s>')
        self._eos_id = self._special_id(tokenizer, '</s>')
        self._pad_id = self._special_id(tokenizer, '<pad>')
        self._user_id = self._special_id(tokenizer, '<user>')
        self._ai_id = self._special_id(tokenizer, '<ai>')

        # 1-element tensor forms, kept as public attributes for external readers.
        self.sos_token = torch.tensor([self._sos_id], dtype=torch.int64)
        self.eos_token = torch.tensor([self._eos_id], dtype=torch.int64)
        self.pad_token = torch.tensor([self._pad_id], dtype=torch.int64)
        self.user_token = torch.tensor([self._user_id], dtype=torch.int64)
        self.ai_token = torch.tensor([self._ai_id], dtype=torch.int64)

        # One flat stream of int ids across all records.
        self.data_tokens = []
        for text in self.ds:
            user_tokens = tokenizer.encode(text['instruction'] + " " + text['input']).ids
            ai_tokens = tokenizer.encode(text['output']).ids
            self.data_tokens.extend(
                [self._user_id, *user_tokens, self._ai_id, *ai_tokens, self._eos_id]
            )

    @staticmethod
    def _special_id(tokenizer, token):
        """Return the vocabulary id of special ``token``, raising if it is absent."""
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"tokenizer has no id for special token {token!r}")
        return token_id

    def __len__(self):
        """Return the number of windows, never negative.

        A token stream shorter than ``seq_len`` yields an empty dataset instead of
        the invalid negative length that ``len()``/DataLoader would reject.
        """
        return max(0, (len(self.data_tokens) - self.seq_len) // self.stride)

    def __getitem__(self, index):
        """Return ``{"input": LongTensor[seq_len], "label": LongTensor[seq_len]}``.

        Negative indices count from the end. Out-of-range indices raise
        IndexError, which also lets plain ``for x in dataset`` iteration stop.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(f"index out of range for dataset of length {length}")

        start = index * self.stride
        window = self.data_tokens[start:start + self.seq_len - 1]
        ids = [self._sos_id, *window, self._pad_id]
        # Defensive: pad to seq_len + 1 ids so input and label are both seq_len long.
        # With the bound check above the window is always full, so this is a no-op.
        ids += [self._pad_id] * (self.seq_len + 1 - len(ids))

        sequence = torch.tensor(ids, dtype=torch.int64)
        # Shift by one: label[k] is the target for input[k].
        return {
            "input": sequence[:-1],
            "label": sequence[1:],
        }
def causal_mask(size):
    """Return a ``(1, size, size)`` boolean attention mask for causal decoding.

    Entry ``[0, i, j]`` is True iff ``j <= i``: position ``i`` may attend to
    itself and earlier positions only (lower triangle, diagonal included).
    The leading singleton dimension lets the mask broadcast over a batch.
    """
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))