import torch
from torch.utils.data import Dataset


class BilingualDataset(Dataset):
    """Packs instruction/response pairs into one long token stream and serves
    overlapping windows of length `seq_len` for causal LM training."""

    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        self.stride = seq_len // 2  # 50% overlap between consecutive windows

        # NOTE: the special-token names below are assumptions; adjust them to
        # whatever names your tokenizer's vocabulary actually defines.
        self.sos_token = tokenizer.token_to_id('<sos>')
        self.eos_token = tokenizer.token_to_id('<eos>')
        self.pad_token = tokenizer.token_to_id('<pad>')
        self.user_token = tokenizer.token_to_id('<user>')
        self.ai_token = tokenizer.token_to_id('<ai>')

        # Flatten the whole dataset into one list of plain int token ids:
        # <user> instruction + input <ai> output <eos> ...
        self.data_tokens = []
        for item in self.ds:
            user_tokens = tokenizer.encode(item['instruction'] + " " + item['input']).ids
            ai_tokens = tokenizer.encode(item['output']).ids
            self.data_tokens.extend(
                [self.user_token] + user_tokens + [self.ai_token] + ai_tokens + [self.eos_token]
            )

    def __len__(self):
        return (len(self.data_tokens) - self.seq_len) // self.stride

    def __getitem__(self, index):
        start = index * self.stride
        window = self.data_tokens[start : start + self.seq_len - 1]
        tokens = [self.sos_token] + window + [self.pad_token]

        # Safety net: pad short windows so input and label are always seq_len long.
        if len(tokens) < self.seq_len + 1:
            tokens += [self.pad_token] * ((self.seq_len + 1) - len(tokens))

        tokens = torch.tensor(tokens, dtype=torch.int64)
        return {
            "input": tokens[:-1],   # (seq_len,)
            "label": tokens[1:],    # (seq_len,) shifted right by one
            # If your attention layer needs an explicit mask, combine a padding
            # mask with the causal mask:
            # (tokens[:-1] != self.pad_token).unsqueeze(0).int() & causal_mask(self.seq_len)
        }


def causal_mask(size):
    # Upper-triangular ones above the diagonal, then invert: True where
    # attention is allowed, i.e. each position sees itself and the past only.
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
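

# --- Usage sketch. Assumptions (not from the original file): a tokenizer
# saved at "tokenizer.json" whose vocabulary contains the special tokens
# above, and an Alpaca-style JSON list at "data.json" with
# "instruction"/"input"/"output" fields. Both paths are hypothetical. ---
if __name__ == "__main__":
    import json
    from tokenizers import Tokenizer
    from torch.utils.data import DataLoader

    tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path
    with open("data.json") as f:                       # hypothetical path
        ds = json.load(f)

    dataset = BilingualDataset(ds, tokenizer, seq_len=256)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    batch = next(iter(loader))
    print(batch["input"].shape, batch["label"].shape)  # torch.Size([8, 256]) each
    print(causal_mask(4))                              # (1, 4, 4) lower-triangular bool mask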