"""
Preprocess multilingual Wikipedia data for model training.

This script performs the following steps:
1. Downloads the Wikimedia Wikipedia dataset for the specified languages
   (https://huggingface.co/datasets/wikimedia/wikipedia).
2. Tokenizes the dataset with a Hugging Face tokenizer corresponding to the
   specified model.
3. Aggregates token IDs up to a target number of tokens per language.
4. Saves the tokenized data as a flat PyTorch tensor for later use.

Usage example:
    python load_datas.py \
        --languages en,zh,vi \
        --model-id meta-llama/Llama-3.1-8B-Instruct \
        --tokenizer meta-llama/Llama-3.1-8B-Instruct \
        --output-dir train-data
"""

from datasets import load_dataset
from transformers import AutoTokenizer
import torch
import os
import traceback
from tqdm import tqdm
import multiprocessing
from functools import partial
import argparse


NUM_PROC_BASE = max(1, (os.cpu_count() or 2) // 2)
TARGET_TOKENS_PER_LANGUAGE = 100_000_000
DATE_SNAPSHOT = "20231101"
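# DATE_SNAPSHOT selects the dump configuration of wikimedia/wikipedia; config
# names have the form "<date>.<lang>", e.g. "20231101.en" (see the
# load_dataset call below).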


def tokenize_function(examples, tokenizer):
    # ds.map() calls this with batched=True, so examples["text"] is a list of
    # document strings and the tokenizer returns one list of token IDs per
    # document. Special tokens, truncation, and padding are all disabled so
    # the raw token streams can be concatenated later.
    output = tokenizer(
        examples["text"],
        add_special_tokens=False,
        truncation=False,
        padding=False,
    )
    return {"input_ids": output.input_ids}
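
# Example call (a sketch; the exact IDs depend on the tokenizer):
#
#     tokenize_function({"text": ["Hello world", "Bonjour"]}, tokenizer)
#     # -> {"input_ids": [[...], [...]]}  (one ID list per input string)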


def build_and_save(
    lang,
    model_id,
    tokenizer_name,
    output_dir,
    num_proc_map=NUM_PROC_BASE,
):
    print(f"Starting data processing for language: {lang}")

    # The output file name encodes the language and the model ID, e.g.
    # "id.en.train.meta-llama_Llama-3.1-8B-Instruct".
    train_filename_base = f"id.{lang}.train.{model_id.replace('/', '_')}"
    train_output_path = os.path.join(output_dir, train_filename_base)

    try:
        ds = load_dataset(
            "wikimedia/wikipedia",
            f"{DATE_SNAPSHOT}.{lang}",
            split="train",
            trust_remote_code=True,
        )
        if len(ds) == 0:
            print(f"Warning: Dataset for {lang} is empty. Skipping.")
            return
    except Exception as e:
        print(f"Error loading dataset for {lang}: {e}")
        raise

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            use_fast=True,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading tokenizer '{tokenizer_name}': {e}")
        raise

    tokenization_func_with_tokenizer = partial(tokenize_function, tokenizer=tokenizer)

    tokenized_ds = ds.map(
        tokenization_func_with_tokenizer,
        batched=True,
        num_proc=num_proc_map,
        remove_columns=ds.column_names,
        desc=f"Tokenizing {lang}",
    )

    # Materialize the per-document token lists so they can be aggregated below.
    all_document_token_lists = []
    for processed_example in tqdm(tokenized_ds, desc=f"Collecting token lists for {lang}"):
        token_list_for_one_doc = processed_example["input_ids"]
        if isinstance(token_list_for_one_doc, list):
            all_document_token_lists.append(token_list_for_one_doc)

    if not all_document_token_lists:
        print(f"Warning: No token sequences found for {lang} after tokenization. Skipping.")
        return

    # Concatenate documents until the per-language token budget is reached,
    # truncating the final document so the total lands exactly on the target.
    final_token_ids = []
    collected_tokens_count = 0
    for doc_tokens_list in tqdm(all_document_token_lists, desc=f"Aggregating tokens for {lang}"):
        if not doc_tokens_list:
            continue

        current_doc_token_count = len(doc_tokens_list)

        if collected_tokens_count + current_doc_token_count <= TARGET_TOKENS_PER_LANGUAGE:
            final_token_ids.extend(doc_tokens_list)
            collected_tokens_count += current_doc_token_count
        else:
            remaining_needed = TARGET_TOKENS_PER_LANGUAGE - collected_tokens_count
            final_token_ids.extend(doc_tokens_list[:remaining_needed])
            collected_tokens_count += remaining_needed
            break

        if collected_tokens_count >= TARGET_TOKENS_PER_LANGUAGE:
            break

    # Free the intermediate structures before building the tensor.
    del all_document_token_lists
    del tokenized_ds
    del ds

    if collected_tokens_count == 0:
        print(f"Warning: Zero tokens collected for {lang}. Skipping save.")
        return

    if collected_tokens_count < TARGET_TOKENS_PER_LANGUAGE:
        print(f"Warning: Language {lang} has only {collected_tokens_count:,} tokens, "
              f"which is less than the target of {TARGET_TOKENS_PER_LANGUAGE:,}.")

    full_tensor = torch.tensor(final_token_ids, dtype=torch.long)
    del final_token_ids

    os.makedirs(output_dir, exist_ok=True)
    torch.save(full_tensor, train_output_path)
    print(f"Saved {full_tensor.numel():,} tokens for {lang}.")
    del full_tensor
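
# Single-language usage (a sketch; the model and tokenizer IDs are examples):
#
#     build_and_save(
#         lang="en",
#         model_id="meta-llama/Llama-3.1-8B-Instruct",
#         tokenizer_name="meta-llama/Llama-3.1-8B-Instruct",
#         output_dir="train-data",
#     )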


def run_job(args):
    # Worker entry point: unpacks one per-language job tuple and reports
    # success or failure back to the parent process.
    lang, model_id, tokenizer_name, output_dir, num_proc_map = args
    print(f"Processing language: {lang} (PID: {os.getpid()})")
    try:
        build_and_save(
            lang=lang,
            model_id=model_id,
            tokenizer_name=tokenizer_name,
            output_dir=output_dir,
            num_proc_map=num_proc_map,
        )
        return lang, True, None
    except Exception as e:
        traceback.print_exc()
        return lang, False, str(e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess Wikipedia data for multiple languages.")
    parser.add_argument(
        "--languages", type=str, default="en,zh,eu,ga",
        help="Comma-separated list of languages to process, e.g. 'en,zh,fr'.",
    )
    parser.add_argument("--model-id", type=str, required=True, help="Model identifier (used for file naming).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path.")
    parser.add_argument("--output-dir", type=str, default="train-data", help="Where to store tokenized tensors.")
    parser.add_argument("--max-concurrent", type=int, default=6, help="Max concurrent language processes (clamped to >= 1).")
    args = parser.parse_args()

    args.languages = [lang.strip() for lang in args.languages.split(",") if lang.strip()]

    # Split the available CPUs between concurrent languages so the nested
    # ds.map() workers do not oversubscribe the machine.
    MAX_CONCURRENT_LANGUAGES = max(1, args.max_concurrent)
    NUM_MAP_PROC_PER_LANG = max(1, NUM_PROC_BASE // MAX_CONCURRENT_LANGUAGES)

    print(f"Starting batch processing for {len(args.languages)} languages.")

    job_args_list = [
        (lang, args.model_id, args.tokenizer, args.output_dir, NUM_MAP_PROC_PER_LANG)
        for lang in args.languages
    ]

    successful_langs = []
    failed_langs_with_errors = {}

    # Note: Pool workers are daemonic, while ds.map() with num_proc > 1 spawns
    # its own children, which daemonic processes may not do. If tokenization
    # fails with "daemonic processes are not allowed to have children", set
    # NUM_MAP_PROC_PER_LANG to 1 or process languages sequentially.
    with multiprocessing.Pool(processes=MAX_CONCURRENT_LANGUAGES) as pool:
        results_iterable = pool.imap_unordered(run_job, job_args_list)
        for result in tqdm(results_iterable, total=len(args.languages), desc="Overall Language Progress"):
            lang_processed, success, error_msg = result
            if success:
                successful_langs.append(lang_processed)
            else:
                failed_langs_with_errors[lang_processed] = error_msg

    print("Batch processing finished.")
    print(f"Successfully processed: {', '.join(sorted(successful_langs))}")
    if failed_langs_with_errors:
        print(f"Failed to process: {', '.join(sorted(failed_langs_with_errors.keys()))}")
        for lang_failed, err in failed_langs_with_errors.items():
            print(f"  - {lang_failed}: {err}")