"""Dataset abstract-base-class module.

Defines the DataBundle abstract base class, which standardizes the dataset
interface. Each concrete dataset (e.g. Wiki, poetry) should subclass it and
implement the corresponding methods.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Optional

import tensorflow as tf
@dataclass
class TokenizerBundle:
    """Tokenizer information wrapper.

    Packs tokenizer-related attributes together to simplify the
    DataBundle interface.

    Attributes are generated into ``__init__`` by ``@dataclass`` — without
    the decorator the annotated fields would be bare class-level
    annotations and the class could not be instantiated with them.
    """

    tokenizer: Callable  # text -> token ids (exact signature defined by the dataset)
    decode: Callable  # token ids -> text, inverse of `tokenizer`
    end_of_text: int  # token id marking end-of-text
    vocab_size: int  # number of entries in the vocabulary
    vocab_path: str = ""  # optional path to the vocabulary file on disk
class DataBundle(ABC):
    """Abstract base class for datasets.

    Binds data loading, tokenization, and statistics together behind a
    uniform dataset interface. Concrete datasets subclass this and
    implement the abstract methods below.

    Usage:
        dataset = WikiDataset(data_dir="~/data/wiki")
        doc_ds = dataset.doc_ds()
        tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
        dataset.stat()
    """

    data_dir: str  # root directory of the raw data files
    sequence_length: int = 256  # default sequence length for subclasses

    @abstractmethod
    def doc_ds(self) -> tf.data.Dataset:
        """Return the raw document dataset.

        Returns:
            A TensorFlow Dataset whose elements are document strings.
        """

    @abstractmethod
    def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
        """Return the tokenized dataset.

        Converts raw documents into token-id sequences and splits them
        into training examples.

        Args:
            seq_length: sequence length of each example.
            batch_size: batch size.

        Returns:
            A TensorFlow Dataset of (input_ids, target_ids) pairs.
        """

    @abstractmethod
    def tokenizer_bundle(self) -> TokenizerBundle:
        """Return the tokenizer information for this dataset."""

    def stat(self, seq_length: int | None = None) -> None:
        """Print dataset statistics.

        Args:
            seq_length: sequence length used to estimate the number of
                training examples; ``None`` lets the report choose.
        """
        # NOTE(review): function-scope import — presumably avoids a
        # circular import between this module and data.common; confirm.
        from data.common import collect_stats

        info = self.tokenizer_bundle()
        stats = collect_stats(
            name=self.__class__.__name__, loader=self.doc_ds, tokenizer=info.tokenizer
        )
        stats.print_report(seq_length=seq_length)