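"""Prepare and cache the training data.

Downloads the raw data, tokenizes it with the base model's tokenizer via
DataProcessor, and saves the resulting train/test DatasetDict under
Config.DATA_DIR for later reuse.
"""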
import os

from transformers import AutoTokenizer

from .config import Config
from .dataset import DataProcessor


def main():
    print("⏳ Starting to download and process data...")

    # Create the output directory if it does not already exist.
    os.makedirs(Config.DATA_DIR, exist_ok=True)

    # Load the base model's tokenizer and build the data processor around it.
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    processor = DataProcessor(tokenizer)

    # Download, tokenize, and split the raw data.
    dataset = processor.get_processed_dataset()

    save_path = os.path.join(Config.DATA_DIR, "processed_dataset")
    print(f"💾 Saving processed dataset to: {save_path}")
    dataset.save_to_disk(save_path)

    print("✅ Dataset saved!")
    print(f"   Train set size: {len(dataset['train'])}")
    print(f"   Test set size: {len(dataset['test'])}")
    print("   To reload it later, use: from datasets import load_from_disk")


if __name__ == "__main__":
    main()
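# For illustration, a minimal reload sketch (based on the hint printed above;
# `load_from_disk` is the standard `datasets` counterpart to `save_to_disk`):
#
#   from datasets import load_from_disk
#   dataset = load_from_disk(os.path.join(Config.DATA_DIR, "processed_dataset"))
#   print(len(dataset["train"]), len(dataset["test"]))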