import torch
from torch.utils.data import IterableDataset


class BilingualDataset(IterableDataset):
    """Stream fixed-length next-token-prediction samples from a text stream.

    Each document is tokenized, terminated with the EOS id, and cut into
    overlapping windows of ``seq_len - 1`` token ids whose starts are
    ``stride`` positions apart.  Every window is prefixed with the SOS id and
    right-padded with the PAD id to exactly ``seq_len`` ids, then split into
    ``input = window[:-1]`` and ``label = window[1:]`` (both of length
    ``seq_len - 1``, dtype ``torch.long``), so ``label[k]`` is the id that
    follows ``input[k]``.  The last window of a document always reaches its
    final id (the EOS), so no text is silently dropped.

    Args:
        ds_stream: Iterable of mapping items; only ``item["text"]`` is read.
        tokenizer: Object providing ``token_to_id(str)`` and
            ``encode(str).ids`` (presumably a HuggingFace
            ``tokenizers.Tokenizer`` -- confirm against the caller).
        seq_len: Window length including the SOS id; must be >= 2.
        stride: Step between window starts; defaults to ``seq_len // 2``.
            A stride larger than ``seq_len - 1`` skips ids between windows.
        sos_text, eos_text, pad_text: Vocabulary strings of the special
            tokens.  NOTE(review): the original literals are all ``''``,
            which looks like angle-bracket tokens (e.g. ``'<sos>'``) that were
            stripped from the source -- confirm the real token names.

    Raises:
        ValueError: if ``seq_len`` < 2, ``stride`` < 1, or the tokenizer has
            no id for one of the special-token strings.

    NOTE(review): under ``DataLoader(num_workers > 1)`` every worker calls
    ``__iter__`` over the same ``ds_stream``; unless the stream shards itself
    per worker, samples will be duplicated -- confirm for the stream in use.
    """

    def __init__(self, ds_stream, tokenizer, seq_len: int, *, stride=None,
                 sos_text: str = '', eos_text: str = '', pad_text: str = ''):
        super().__init__()
        if seq_len < 2:
            raise ValueError(f"seq_len must be >= 2, got {seq_len}")
        if stride is None:
            stride = seq_len // 2  # >= 1 because seq_len >= 2
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        self.ds_stream = ds_stream
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.stride = stride
        self.sos_token = self._special_id(sos_text)
        self.eos_token = self._special_id(eos_text)
        self.pad_token = self._special_id(pad_text)

    def _special_id(self, text):
        """Return the vocabulary id for special-token string *text*.

        Fails fast when ``token_to_id`` returns None, instead of letting None
        leak into ``torch.tensor`` during iteration.
        """
        token_id = self.tokenizer.token_to_id(text)
        if token_id is None:
            raise ValueError(f"tokenizer has no id for special token {text!r}")
        return token_id

    def process_text(self, text):
        """Yield ``{"input": ..., "label": ...}`` samples covering all of *text*.

        Window starts are ``0, stride, 2*stride, ...`` up to and including the
        first start whose window reaches the last id, so the document tail and
        its EOS id are always emitted.  Every label is the true next id, except
        positions after the EOS, which are PAD.
        """
        token_ids = [*self.tokenizer.encode(text).ids, self.eos_token]
        width = self.seq_len - 1  # content ids per window; SOS fills the remaining slot
        # Choosing stop = max(n - width, 0) + stride makes the last start s the
        # first multiple of stride with s + width >= n (full tail coverage),
        # while never producing a window that starts past the end.
        stop = max(len(token_ids) - width, 0) + self.stride
        for start in range(0, stop, self.stride):
            window = [self.sos_token] + token_ids[start:start + width]
            window += [self.pad_token] * (self.seq_len - len(window))
            yield {
                "input": torch.tensor(window[:-1], dtype=torch.long),
                "label": torch.tensor(window[1:], dtype=torch.long),
            }

    def __iter__(self):
        """Yield the samples of every ``ds_stream`` item, in stream order."""
        for item in self.ds_stream:
            yield from self.process_text(item["text"])


# NOTE: legacy map-style implementation kept for reference only.  It is a bare
# string literal, so nothing inside it is executed or importable (including
# `causal_mask`).
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import json


class BilingualDataset(Dataset):
    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        self.stride = seq_len//2
        self.sos_token = torch.tensor([tokenizer.token_to_id('')],dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer.token_to_id('')],dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer.token_to_id('')],dtype=torch.int64)
        self.data_tokens = []
        for text in self.ds:
            # text = text['instruction'] +" ### " + text['text'] + " \n" + text['output']
            # text = text['user'] +" ### " + text['ai']
            text = text['text']
            tokens = tokenizer.encode(text).ids
            self.data_tokens.extend(tokens + [self.eos_token])

    def __len__(self):
        return (len(self.data_tokens) - self.seq_len) // self.stride

    def __getitem__(self, index):
        input_tokens = torch.tensor(self.data_tokens[index*self.stride:(index*self.stride)+self.seq_len- 1]).tolist()
        input_tokens = [self.sos_token] + input_tokens + [self.pad_token]
        if len(input_tokens) < self.seq_len - 1:
            input_tokens+=[self.pad_token] * ((self.seq_len - 1 ) - len(input_tokens))
        input_tokens = torch.tensor(input_tokens)
        return {
            "input": input_tokens[:-1],
            # "input_mask": (input_tokens[:-1] != self.pad_token).unsqueeze(0).int() & causal_mask(input_tokens[:-1].size(0)), # (1, seq_len) & (1, seq_len, seq_len)
            "label":input_tokens[1:] # ^ CONFUSION SYNTAX :)
        }


def causal_mask(size):
    mask = torch.triu(torch.ones(1,size,size), diagonal=1).type(torch.int)
    return mask == 0
"""