# scan-8192-16M-test / preprocess.py
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
import logging
from itertools import chain
from typing import Any, Dict, List, Optional

from datasets import load_dataset
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: AutoTokenizer,
    context_length: int
) -> Dict[str, List[List[int]]]:
    """
    Tokenize the input text and split it into chunks of the specified context length.

    Args:
        examples:
            Dictionary containing the input text.
        tokenizer:
            Initialized tokenizer.
        context_length:
            Length of each context chunk.

    Returns:
        Dictionary containing the tokenized and chunked input ids.
    """
    text = examples['text']
    input_ids = tokenizer(text)['input_ids']
    # Concatenate the token ids of every document in the batch into one flat stream
    input_ids = list(chain(*input_ids))
    total_length = len(input_ids)
    # Round down to a multiple of context_length; the last chunk shorter than
    # context_length is discarded
    total_length = (total_length // context_length) * context_length
    return {'input_ids': [input_ids[i:i+context_length] for i in range(0, total_length, context_length)]}
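
# Illustrative sketch (not part of the original script): with context_length=4 and a
# batch whose concatenated token ids are [1, 2, ..., 10], the function above returns
# chunks [1, 2, 3, 4] and [5, 6, 7, 8]; the trailing [9, 10] is dropped because it is
# shorter than context_length.
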
def preprocess(
    dataset: str,
    name: Optional[str] = None,
    split: str = 'train',
    output: str = 'data',
    model: str = 'mistralai/Mistral-7B-v0.1',
    num_proc: int = 64,
    context_length: int = 8192
) -> None:
    """
    Load, tokenize, and save the processed dataset.

    Args:
        dataset:
            Path or name of the dataset.
        name:
            Name of the dataset configuration.
        split:
            Dataset split to process.
        output:
            Output directory.
        model:
            Model name for the tokenizer.
        num_proc:
            Number of processes for parallel processing.
        context_length:
            Context length for tokenization.
    """
    tokenized_path = f'{output}/{dataset}/{name}/{split}' if name is not None else f'{output}/{dataset}/{split}'

    logging.info(f'Initializing tokenizer of {model}')
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    logging.info(f'Tokenizer initialized: {tokenizer}')

    logging.info(f'Loading dataset: {dataset}')
    dataset = load_dataset(dataset, name=name, split=split)
    # Drop all original columns so that only the chunked input ids remain after mapping
    remove_columns = list(next(iter(dataset)).keys())

    logging.info('Tokenizing and processing dataset')
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer, context_length),
        batched=True,
        remove_columns=remove_columns,
        num_proc=num_proc,
        desc="Running tokenizer on dataset"
    )

    logging.info(f'Saving processed dataset to {tokenized_path}')
    dataset.save_to_disk(tokenized_path, num_proc=num_proc)
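
# Sketch of downstream use (an assumption, not part of this script): the saved dataset
# can be reloaded with datasets.load_from_disk, where each row holds an 'input_ids'
# list of exactly context_length tokens, e.g.
#   from datasets import load_from_disk
#   data = load_from_disk('data/HuggingFaceFW/fineweb-edu/train')
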
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess and tokenize dataset")
    parser.add_argument("--dataset", default="HuggingFaceFW/fineweb-edu", help="Path or name of the dataset")
    parser.add_argument("--name", default=None, help="Name of the dataset configuration")
    parser.add_argument("--split", default="train", help="Dataset split to process")
    parser.add_argument("--output", default="data", help="Output directory")
    parser.add_argument("--model", default="mistralai/Mistral-7B-v0.1", help="Model name for tokenizer")
    parser.add_argument("--num_proc", type=int, default=64, help="Number of processes for parallel processing")
    parser.add_argument("--context_length", type=int, default=8192, help="Context length for tokenization")
    args = parser.parse_args()

    preprocess(
        dataset=args.dataset,
        name=args.name,
        split=args.split,
        output=args.output,
        model=args.model,
        num_proc=args.num_proc,
        context_length=args.context_length
    )
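
# Example invocation (illustrative; dataset, model, and output paths depend on your setup):
#   python preprocess.py \
#       --dataset HuggingFaceFW/fineweb-edu \
#       --split train \
#       --model mistralai/Mistral-7B-v0.1 \
#       --context_length 8192 \
#       --output data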