yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""数据集抽象基类模块
定义 DataBundle 抽象基类,统一数据集的接口规范。
每个具体的数据集(如 Wiki、诗歌)都应该继承此类并实现相应方法。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Optional
import tensorflow as tf
@dataclass
class TokenizerBundle:
"""分词器信息包装类
将分词器相关的属性打包在一起,简化 DataBundle 接口。
"""
tokenizer: Callable
decode: Callable
end_of_text: int
vocab_size: int
vocab_path: str = ""
@dataclass
class DataBundle(ABC):
"""数据集抽象基类
将数据加载、分词、统计等功能绑定在一起,提供统一的数据集接口。
Usage:
dataset = WikiDataset(data_dir="~/data/wiki")
doc_ds = dataset.doc_ds()
tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
dataset.stat()
"""
data_dir: str
sequence_length: int = 256
@abstractmethod
def doc_ds(self) -> tf.data.Dataset:
"""返回原始文档数据集
Returns:
TensorFlow Dataset,每个元素是一个文档字符串
"""
pass
@abstractmethod
def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
"""返回 tokenized 数据集
将原始文档转换为 token ID 序列,并分割为训练样本。
Args:
seq_length: 序列长度
batch_size: 批次大小
Returns:
TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
"""
pass
@abstractmethod
def tokenizer_bundle(self) -> TokenizerBundle:
"""返回分词器信息"""
pass
def stat(self, seq_length: int | None = None) -> None:
"""打印数据集统计信息
Args:
seq_length: 序列长度,用于估算训练样本数
"""
from data.common import collect_stats
info = self.tokenizer_bundle()
stats = collect_stats(
name=self.__class__.__name__, loader=self.doc_ds, tokenizer=info.tokenizer
)
stats.print_report(seq_length=seq_length)