import sys
from typing import Dict, Any, Optional, List

from datasets import load_dataset, DatasetDict, Dataset
from transformers import PreTrainedTokenizerBase

from .utils import logger


def load_and_prepare_dataset(
    dataset_repo_id: str,
    data_dir: Optional[str],
    source_column: str,
    target_column: str,
    tokenizer: PreTrainedTokenizerBase,
    block_size: int,
    eval_strategy: str,
) -> DatasetDict:
    """Loads dataset, renames column, tokenizes, and optionally groups texts."""
    logger.info(f"Loading dataset from Hub: {dataset_repo_id} (data_dir: {data_dir})")
    try:
        raw_datasets = load_dataset(dataset_repo_id, data_dir=data_dir)
        logger.info(f"Dataset loaded: {raw_datasets}")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}", exc_info=True)
        sys.exit(1)

    logger.info(f"Renaming column '{source_column}' to '{target_column}' and removing others.")
    try:
        def rename_and_keep_column(example: Dict[str, Any]) -> Dict[str, Any]:
            if source_column not in example:
                raise KeyError(f"Source column '{source_column}' not found in example: {list(example.keys())}")
            return {target_column: example[source_column]}
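
        # Example (illustrative values, not from this module): with
        # source_column="text" and target_column="content", the row
        # {"text": "hi", "meta": 1} maps to {"content": "hi"}.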
|
| | column_names_to_remove = {} |
| | for split in raw_datasets.keys(): |
| | column_names_to_remove[split] = [name for name in raw_datasets[split].column_names if name != source_column] |
| | |
| | if source_column in column_names_to_remove[split]: |
| | column_names_to_remove[split].remove(source_column) |
| |
|
| |
|
| | processed_datasets = DatasetDict() |
| | for split, original_cols in raw_datasets.items(): |
| | cols_to_remove = [col for col in original_cols.column_names if col != source_column] |
| | processed_datasets[split] = raw_datasets[split].map( |
| | rename_and_keep_column, |
| | batched=False, |
| | remove_columns=cols_to_remove |
| | ) |
| | logger.info(f"Dataset after column renaming: {processed_datasets}") |
| |
|
    except KeyError as e:
        logger.error(f"Error during column renaming: {e}. Ensure '{source_column}' exists.", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected error occurred during column renaming/cleanup: {e}", exc_info=True)
        sys.exit(1)

    logger.info("Tokenizing dataset...")

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[Any]]:
        # Truncate to block_size when one is given; with max_length=None the
        # tokenizer falls back to the model's maximum input length.
        return tokenizer(examples[target_column], truncation=True, max_length=block_size if block_size else None)

    try:
        tokenized_datasets = processed_datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=processed_datasets["train"].column_names,
            desc="Running tokenizer on dataset",
        )
        logger.info("Tokenization complete.")
    except Exception as e:
        logger.error(f"Error during tokenization: {e}", exc_info=True)
        sys.exit(1)
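
    # At this point each split contains only the tokenizer's outputs (typically
    # input_ids and attention_mask); the text column was dropped by map() above.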
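
    # NOTE: the grouping step promised by the docstring ("optionally groups
    # texts") is elided from this listing, and eval_strategy goes unused in the
    # code shown. A minimal sketch of the standard block-grouping approach (an
    # assumption, not the original implementation) would concatenate the
    # tokenized sequences and re-chunk them into block_size pieces:
    #
    #     def group_texts(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    #         concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    #         total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    #         return {
    #             k: [seq[i : i + block_size] for i in range(0, total_length, block_size)]
    #             for k, seq in concatenated.items()
    #         }
    #
    #     tokenized_datasets = tokenized_datasets.map(
    #         group_texts, batched=True, desc="Grouping texts into blocks"
    #     )
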
| | logger.info(f"Processed dataset structure (tokenized only): {tokenized_datasets}") |
| | return tokenized_datasets |
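

# A hypothetical usage sketch; the repo id, column names, and block size below
# are illustrative assumptions, not values taken from this project:
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     dataset = load_and_prepare_dataset(
#         dataset_repo_id="user/my-text-dataset",
#         data_dir=None,
#         source_column="text",
#         target_column="content",
#         tokenizer=tok,
#         block_size=1024,
#         eval_strategy="no",
#     )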