import torch
from torch.utils.data import Dataset
from datasets import load_dataset
class ChessDataset(Dataset):
    """Wraps a dataset of game texts as tokenized examples for causal LM training."""

    def __init__(self, data, tokenizer, block_size):
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        # truncation=True is required for max_length to take effect;
        # without it the tokenizer returns the full, untruncated sequence.
        tokens = self.tokenizer(text, max_length=self.block_size, truncation=True)["input_ids"]
        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class ChessDataCollator:
    """Pads a batch and builds labels with padded positions masked to -100."""

    def __init__(self, tokenizer=None, max_length=None):
        # Use the tokenizer's pad token if one is set; 0 is a safe fallback
        # because padded positions are excluded from the loss below.
        self.pad_id = getattr(tokenizer, "pad_token_id", None) or 0

    def __call__(self, features):
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [f["input_ids"] for f in features], batch_first=True, padding_value=self.pad_id)
        mask = torch.nn.utils.rnn.pad_sequence(
            [f["attention_mask"] for f in features], batch_first=True, padding_value=0)
        labels = input_ids.clone()
        labels[mask == 0] = -100  # -100 is the ignore_index of CrossEntropyLoss
        return {"input_ids": input_ids, "attention_mask": mask, "labels": labels}
def create_train_val_datasets(dataset_name, tokenizer, val_samples=1000, **kwargs):
    """Loads a dataset by name and returns (train, validation) ChessDatasets."""
    max_train = kwargs.get("train_samples", kwargs.get("max_train_samples", 50000))
    block_size = kwargs.get("n_ctx", kwargs.get("max_length", 256))
    ds = load_dataset(dataset_name, split="train")
    # Cap the dataset so the split yields at most max_train training rows.
    if len(ds) > max_train + val_samples:
        ds = ds.select(range(max_train + val_samples))
    split = ds.train_test_split(test_size=val_samples)
    return (ChessDataset(split["train"], tokenizer, block_size),
            ChessDataset(split["test"], tokenizer, block_size))
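

# Usage sketch, not part of the module proper: builds the train/val datasets
# and one padded batch. The GPT-2 tokenizer and the dataset id below are
# assumptions for illustration; substitute any HF dataset with a "text" column.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    train_ds, val_ds = create_train_val_datasets(
        "your-org/chess-games-text",  # hypothetical dataset id
        tokenizer,
        val_samples=1000,
        max_length=256,
    )
    loader = DataLoader(train_ds, batch_size=8, collate_fn=ChessDataCollator(tokenizer))
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)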