"""
Data loading utilities for the Chess Challenge, using the color+piece/from/to tokenizer.
"""
from __future__ import annotations
from typing import Dict, Iterator, List, Optional
import torch
from torch.utils.data import Dataset
class ChessDataset(Dataset):
"""
PyTorch Dataset for chess games with color+piece/from/to tokenizer.
"""
def __init__(
self,
tokenizer,
dataset_name: str = "dlouapre/lichess_2025-01_1M",
split: str = "train",
column: str = "text",
max_length: int = 256,
max_samples: Optional[int] = None,
):
from datasets import load_dataset
self.tokenizer = tokenizer
self.max_length = max_length
self.column = column
# Load dataset
dataset = load_dataset(dataset_name, split=split)
if max_samples is not None:
dataset = dataset.select(range(min(max_samples, len(dataset))))
self.data = dataset
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
game = self.data[idx][self.column]
# Prepend BOS token
game_with_bos = self.tokenizer.bos_token + " " + game
        # Tokenize: the tokenizer already splits each move into color+piece/from/to tokens
encoding = self.tokenizer(
game_with_bos,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
input_ids = encoding["input_ids"].squeeze(0)
attention_mask = encoding["attention_mask"].squeeze(0)
        # Labels mirror input_ids; the model shifts them internally for
        # next-token prediction.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # -100 is ignored by the loss (padding)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
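
# Usage sketch for ChessDataset (illustrative, not from the original file;
# assumes `tokenizer` is this repo's color+piece/from/to tokenizer with a
# Hugging Face-style __call__ and a `bos_token` attribute):
#
#     dataset = ChessDataset(tokenizer, max_samples=10_000)
#     sample = dataset[0]
#     sample["input_ids"].shape              # torch.Size([256])
#     int((sample["labels"] == -100).sum())  # padded positions excluded from loss
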
class ChessDataCollator:
"""Data collator for chess games."""
def __init__(self, tokenizer, max_length: int = 256):
self.tokenizer = tokenizer
self.max_length = max_length
def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([f["input_ids"] for f in features])
attention_mask = torch.stack([f["attention_mask"] for f in features])
labels = torch.stack([f["labels"] for f in features])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
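
# A minimal wiring sketch (not part of the original pipeline) showing how the
# collator plugs into torch's DataLoader; `make_dataloader` is a hypothetical
# convenience helper added here for illustration.
def make_dataloader(dataset: ChessDataset, batch_size: int = 32, shuffle: bool = True):
    """Illustrative helper (not in the original repo): batch a ChessDataset."""
    from torch.utils.data import DataLoader

    collator = ChessDataCollator(dataset.tokenizer, max_length=dataset.max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)
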
def create_train_val_datasets(
tokenizer,
dataset_name: str = "dlouapre/lichess_2025-01_1M",
max_length: int = 256,
train_samples: Optional[int] = None,
val_samples: int = 5000,
val_ratio: float = 0.05,
):
    """
    Split the dataset into contiguous train/validation subsets. If
    ``train_samples`` is None, ``val_ratio`` sets the split point;
    the validation size is capped at ``val_samples``.
    """
    from datasets import load_dataset
full_dataset = load_dataset(dataset_name, split="train")
total = len(full_dataset)
if train_samples is not None:
n_train = min(train_samples, total - val_samples)
else:
n_train = int(total * (1 - val_ratio))
n_val = min(val_samples, total - n_train)
train_data = full_dataset.select(range(n_train))
val_data = full_dataset.select(range(n_train, n_train + n_val))
    # Wrap each split; the constructor re-loads the dataset (served from the
    # local Hugging Face cache), and its ``data`` attribute is then replaced
    # with the pre-sliced split.
    train_dataset = ChessDataset(tokenizer=tokenizer, dataset_name=dataset_name, max_length=max_length)
    train_dataset.data = train_data
    val_dataset = ChessDataset(tokenizer=tokenizer, dataset_name=dataset_name, max_length=max_length)
    val_dataset.data = val_data
return train_dataset, val_dataset
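
# Usage sketch (illustrative): splitting the 1M-game corpus into
# 100k train / 5k validation games.
#
#     train_ds, val_ds = create_train_val_datasets(tokenizer, train_samples=100_000)
#     len(train_ds), len(val_ds)   # (100000, 5000)
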
def stream_games(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
) -> Iterator[str]:
"""Stream games for memory-efficient processing."""
from datasets import load_dataset
dataset = load_dataset(dataset_name, split=split, streaming=True)
for example in dataset:
yield example[column]
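
# Streaming usage sketch: peek at a few games without materializing the
# whole split (itertools.islice caps the lazy iterator).
#
#     from itertools import islice
#     for game in islice(stream_games(), 3):
#         print(game[:80])
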
def analyze_dataset_statistics(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_samples: int = 10000,
) -> Dict:
    """Compute summary statistics (game lengths, move frequencies,
    common openings) over a sample of the dataset."""
from collections import Counter
from datasets import load_dataset
dataset = load_dataset(dataset_name, split="train")
dataset = dataset.select(range(min(max_samples, len(dataset))))
game_lengths = []
move_counts = Counter()
opening_moves = Counter()
for example in dataset:
moves = example["text"].strip().split()
game_lengths.append(len(moves))
move_counts.update(moves)
if len(moves) >= 4:
opening = " ".join(moves[:4])
opening_moves[opening] += 1
return {
"total_games": len(dataset),
"avg_game_length": sum(game_lengths) / len(game_lengths),
"min_game_length": min(game_lengths),
"max_game_length": max(game_lengths),
"unique_moves": len(move_counts),
"most_common_moves": move_counts.most_common(20),
"most_common_openings": opening_moves.most_common(10),
}
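
if __name__ == "__main__":
    # Smoke-test sketch (added for illustration): summarize 1,000 games.
    # Downloads the dataset's "train" split on first run; network required.
    stats = analyze_dataset_statistics(max_samples=1000)
    print(f"Games analyzed: {stats['total_games']}")
    print(f"Avg / min / max game length: "
          f"{stats['avg_game_length']:.1f} / {stats['min_game_length']} / {stats['max_game_length']}")
    print("Top openings:")
    for opening, count in stats["most_common_openings"][:5]:
        print(f"  {count:6d}  {opening}")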