|
|
""" |
|
|
Data loading utilities for the Chess Challenge using color+piece/from/to tokenizer. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
from typing import Dict, Iterator, List, Optional |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
class ChessDataset(Dataset):
    """PyTorch Dataset for chess games with the color+piece/from/to tokenizer.

    Each item is one game string from ``column``, prefixed with the
    tokenizer's BOS token and a space, then tokenized to exactly
    ``max_length`` tokens (truncated or padded). ``labels`` are a copy of
    ``input_ids`` with padded positions (``attention_mask == 0``) set to
    ``-100`` so they are ignored by the loss.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        max_length: int = 256,
        max_samples: Optional[int] = None,
        data=None,
    ):
        """Build the dataset.

        Args:
            tokenizer: Callable tokenizer exposing ``bos_token``. It is called
                with ``truncation``/``max_length``/``padding``/``return_tensors``
                keyword arguments and must return ``input_ids`` and
                ``attention_mask`` tensors with a leading dimension of size 1.
            dataset_name: Hugging Face dataset id. Only used when ``data`` is None.
            split: Split to load. Only used when ``data`` is None.
            column: Field of each row that holds the game string.
            max_length: Fixed token length of every encoded example.
            max_samples: If set, keep only the first ``max_samples`` rows.
            data: Optional pre-loaded rows, e.g. a ``datasets.Dataset`` or a
                list of dicts. When given, no dataset is loaded. This lets
                callers that already hold the data avoid a redundant
                ``load_dataset`` call.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

        if data is None:
            # Imported lazily so the class also works with pre-loaded data
            # when the `datasets` package is not installed.
            from datasets import load_dataset

            data = load_dataset(dataset_name, split=split)

        if max_samples is not None:
            # Clamp at 0 so a negative value yields an empty dataset, matching
            # `select(range(negative))` on the Hugging Face path.
            limit = max(0, min(max_samples, len(data)))
            if hasattr(data, "select"):
                data = data.select(range(limit))
            else:
                data = data[:limit]

        self.data = data

    def __len__(self) -> int:
        """Return the number of games."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize game ``idx`` and return ``input_ids``, ``attention_mask``, ``labels``."""
        game = self.data[idx][self.column]

        # Prepend BOS so the model always sees an explicit start-of-game token.
        game_with_bos = self.tokenizer.bos_token + " " + game

        encoding = self.tokenizer(
            game_with_bos,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # return_tensors="pt" adds a batch dimension of 1; drop it.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # -100 is excluded from the loss, so padding does not contribute.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


class ChessDataCollator:
    """Data collator for chess games.

    Stacks per-example tensors into batch tensors with ``torch.stack``.
    All features must therefore share the same length, which holds for
    ``ChessDataset`` because it pads every example to ``max_length``.
    ``tokenizer`` and ``max_length`` are stored but not used by ``__call__``.
    """

    def __init__(self, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Stack ``input_ids``, ``attention_mask`` and ``labels`` across ``features``."""
        input_ids = torch.stack([f["input_ids"] for f in features])
        attention_mask = torch.stack([f["attention_mask"] for f in features])
        labels = torch.stack([f["labels"] for f in features])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


def create_train_val_datasets(
    tokenizer,
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_length: int = 256,
    train_samples: Optional[int] = None,
    val_samples: int = 5000,
    val_ratio: float = 0.05,
):
    """Split the ``train`` split of ``dataset_name`` into train/val datasets.

    The first ``n_train`` rows form the training set and the next ``n_val``
    rows form the validation set. No shuffling is done.

    * If ``train_samples`` is given: ``n_train = min(train_samples, total - val_samples)``.
    * Otherwise: ``n_train = int(total * (1 - val_ratio))``.
    * ``n_val = min(val_samples, total - n_train)``.

    Both sizes are clamped to ``[0, total]``. A dataset with fewer rows than
    ``val_samples``, or an out-of-range ``val_ratio``, therefore yields short
    or empty splits instead of negative or out-of-range row indices.

    Returns:
        ``(train_dataset, val_dataset)`` as ``ChessDataset`` instances.
    """
    from datasets import load_dataset

    full_dataset = load_dataset(dataset_name, split="train")
    total = len(full_dataset)

    if train_samples is not None:
        n_train = min(train_samples, total - val_samples)
    else:
        n_train = int(total * (1 - val_ratio))

    # Without clamping, total < val_samples (or val_ratio > 1) makes n_train
    # negative, and val_ratio < 0 makes it exceed total.
    n_train = max(0, min(n_train, total))
    n_val = max(0, min(val_samples, total - n_train))

    # Hand the slices to ChessDataset directly so the dataset is loaded once
    # here rather than again inside each constructor.
    train_dataset = ChessDataset(
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        max_length=max_length,
        data=full_dataset.select(range(n_train)),
    )
    val_dataset = ChessDataset(
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        max_length=max_length,
        data=full_dataset.select(range(n_train, n_train + n_val)),
    )

    return train_dataset, val_dataset
|
|
|
|
|
|
|
|
def stream_games(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
) -> Iterator[str]:
    """Lazily yield the ``column`` value of every row in ``dataset_name``/``split``.

    The dataset is opened with ``streaming=True``, so rows are fetched as
    they are consumed instead of materialising the whole split. Because this
    is a generator, nothing is imported or loaded until it is first advanced.
    """
    from datasets import load_dataset

    rows = load_dataset(dataset_name, split=split, streaming=True)
    yield from (row[column] for row in rows)
|
|
|
|
|
|
|
|
def analyze_dataset_statistics(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_samples: int = 10000,
    column: str = "text",
) -> Dict:
    """Compute simple statistics over the first ``max_samples`` games.

    Each game is the whitespace-split ``column`` field of a row from the
    ``train`` split. Every token is counted as one move.

    Args:
        dataset_name: Hugging Face dataset id to load.
        max_samples: Maximum number of games (rows) to analyse.
        column: Field holding the game string. Defaults to ``"text"``, the
            column that was previously hard-coded.

    Returns:
        A dict with these keys:

        * ``total_games``
        * ``avg_game_length``, ``min_game_length``, ``max_game_length``
        * ``unique_moves``
        * ``most_common_moves``: top 20 ``(move, count)`` pairs
        * ``most_common_openings``: top 10 space-joined first-4-move sequences
          with their counts

        If no games are selected, the length statistics are 0 instead of
        raising ``ZeroDivisionError``/``ValueError``.
    """
    from collections import Counter

    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    game_lengths: List[int] = []
    move_counts = Counter()
    opening_moves = Counter()

    for example in dataset:
        # str.split() with no argument already discards surrounding whitespace.
        moves = example[column].split()
        game_lengths.append(len(moves))
        move_counts.update(moves)
        if len(moves) >= 4:
            opening_moves[" ".join(moves[:4])] += 1

    n_games = len(game_lengths)
    return {
        "total_games": len(dataset),
        "avg_game_length": sum(game_lengths) / n_games if n_games else 0.0,
        "min_game_length": min(game_lengths, default=0),
        "max_game_length": max(game_lengths, default=0),
        "unique_moves": len(move_counts),
        "most_common_moves": move_counts.most_common(20),
        "most_common_openings": opening_moves.most_common(10),
    }
|
|
|