"""Module containing data utilities

Loading, tokenizing, caching and splitting of fine-tuning datasets, plus the
streaming tokenize/pack pipeline used for pretraining datasets.
"""
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
    AlpacaMultipleChoicePromptTokenizingStrategy,
    AlpacaPromptTokenizingStrategy,
    AlpacaReflectionPTStrategy,
    GPTeacherPromptTokenizingStrategy,
    JeopardyPromptTokenizingStrategy,
    OpenAssistantPromptTokenizingStrategy,
    SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
    AlpacaPrompter,
    GPTeacherPrompter,
    JeopardyPrompter,
    MultipleChoiceConcisePrompter,
    MultipleChoiceExplainPrompter,
    Prompter,
    ReflectAlpacaPrompter,
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    process_pretraining_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")


def md5(to_hash: str, encoding: str = "utf-8") -> str:
    """Return the hex MD5 digest of ``to_hash``.

    Only used in this module to build cache keys and dataset fingerprints.
    ``usedforsecurity`` is not accepted by older Pythons, which raise
    ``TypeError``; in that case the plain call is used instead.
    """
    try:
        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
    except TypeError:
        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec


def prepare_dataset(cfg, tokenizer):
    """Build the train/eval datasets and compute the total number of training steps.

    Returns:
        ``(train_dataset, eval_dataset, total_num_steps, prompters)``. For a
        ``cfg.pretraining_dataset`` the train set is a streaming dataset,
        ``eval_dataset`` is ``None`` and the step count is ``cfg.max_steps``.

    Raises:
        ValueError: if eval sample packing is enabled but yields zero steps.
    """
    prompters = []
    if not cfg.pretraining_dataset:
        # NOTE(review): zero_first presumably lets the main process build/cache
        # the dataset before the other ranks run this block — confirm.
        with zero_first(is_main_process()):
            train_dataset, eval_dataset, prompters = load_prepare_datasets(
                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
            )
    else:
        path = cfg.pretraining_dataset
        name = None
        if isinstance(cfg.pretraining_dataset, list) and isinstance(
            cfg.pretraining_dataset[0], dict
        ):
            path = cfg.pretraining_dataset[0]["path"]
            # `name` is optional; .get avoids a KeyError for plain dicts
            name = cfg.pretraining_dataset[0].get("name")

        train_dataset = load_pretraining_dataset(
            path,
            tokenizer,
            cfg,
            name=name,
            max_tokens=cfg.sequence_len,
            seed=cfg.seed or 42,
        )
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None
        return train_dataset, eval_dataset, cfg.max_steps, prompters

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
            )

    if cfg.max_steps:
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
    return train_dataset, eval_dataset, total_num_steps, prompters


def load_tokenized_prepared_datasets(
    tokenizer, cfg, default_dataset_prepared_path
) -> Tuple[Dataset, List[Prompter]]:
    """Load (or build and cache) the merged, tokenized training dataset.

    Lookup order:

    1. ``{cfg.push_dataset_to_hub}/{ds_hash}`` on the hub (best effort).
    2. ``{cfg.dataset_prepared_path}/{ds_hash}`` on disk, unless preprocessing.
    3. Otherwise each entry in ``cfg.datasets`` is loaded, tokenized via
       ``get_dataset_wrapper``, concatenated, post-processed with
       ``process_datasets_for_packing``, then saved to disk (and optionally
       pushed to the hub) from local rank 0.

    Returns:
        ``(dataset, prompters)``. ``prompters`` is only populated when the
        datasets were built from raw data in this call.

    Raises:
        ImportError: for s3:// or gs:// paths without the filesystem extras.
        ValueError: when a dataset cannot be loaded or has no usable split.
    """
    tokenizer_name = tokenizer.__class__.__name__
    # NOTE(review): the cache key only covers sequence_len, each dataset's
    # path/type/shards/conversation and the tokenizer *class* name; changing
    # e.g. `name` or `data_files` re-uses a stale cache.
    dataset_signatures = sorted(
        f"{d.path}:{d.type}:{d.shards}:{d.conversation}" for d in cfg.datasets
    )
    ds_hash = md5(
        str(cfg.sequence_len)
        + "@"
        + "|".join(dataset_signatures)
        + "|"
        + tokenizer_name
    )
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(default_dataset_prepared_path) / ds_hash
    )
    dataset = None
    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
            dataset = load_dataset(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                token=use_auth_token,
            )
            dataset = dataset["train"]
    except Exception:  # pylint: disable=broad-except # nosec
        # best effort: any failure just means we fall through to disk / raw data
        pass

    if dataset:
        # already fetched from the hub above; nothing more to do
        pass
    elif (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
            )

        if cfg.seed:
            seed = cfg.seed
        else:
            LOG.info("No seed provided, using default seed of 42")
            seed = 42

        datasets = []

        def for_d_in_datasets(dataset_configs):
            # expand an entry whose `name` is a list into one entry per name
            for dataset_config in dataset_configs:
                if dataset_config.name and isinstance(dataset_config.name, list):
                    for name in dataset_config.name:
                        yield DictDefault({**dataset_config, "name": name})
                else:
                    yield dataset_config

        # pylint: disable=invalid-name
        for config_dataset in for_d_in_datasets(cfg.datasets):
            ds: Union[Dataset, DatasetDict] = None

            # probe the hub with a streaming load; only "not found" and
            # connection errors are treated as "not on the hub"
            ds_from_hub = False
            try:
                load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=True,
                    token=use_auth_token,
                )
                ds_from_hub = True
            except (FileNotFoundError, ConnectionError):
                pass

            ds_from_cloud = False
            storage_options = {}
            remote_file_system = None
            if config_dataset.path.startswith("s3://"):
                try:
                    import aiobotocore.session  # type: ignore
                    import s3fs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "s3:// paths require aiobotocore and s3fs to be installed"
                    ) from exc

                # Takes credentials from ~/.aws/credentials for default profile
                s3_session = aiobotocore.session.AioSession(profile="default")
                storage_options = {"session": s3_session}
                remote_file_system = s3fs.S3FileSystem(**storage_options)
            elif config_dataset.path.startswith(("gs://", "gcs://")):
                try:
                    import gcsfs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "gs:// or gcs:// paths require gcsfs to be installed"
                    ) from exc

                # gcsfs will use default credentials from the environment else anon
                # https://gcsfs.readthedocs.io/en/latest/#credentials
                storage_options = {"token": None}
                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
            # TODO: Figure out how to get auth creds passed
            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
            #     try:
            #         import adlfs
            #     except ImportError as exc:
            #        raise ImportError(
            #            "adl:// or abfs:// paths require adlfs to be installed"
            #        ) from exc

            #     # Gen 1
            #     storage_options = {
            #         "tenant_id": TENANT_ID,
            #         "client_id": CLIENT_ID,
            #         "client_secret": CLIENT_SECRET,
            #     }
            #     # Gen 2
            #     storage_options = {
            #         "account_name": ACCOUNT_NAME,
            #         "account_key": ACCOUNT_KEY,
            #     }

            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
            try:
                if remote_file_system and remote_file_system.exists(
                    config_dataset.path
                ):
                    ds_from_cloud = True
            except (FileNotFoundError, ConnectionError):
                pass

            # prefer local dataset, even if hub exists
            local_path = Path(config_dataset.path)
            if local_path.exists():
                if local_path.is_dir():
                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                    ds = load_dataset(
                        config_dataset.path,
                        name=config_dataset.name,
                        data_files=config_dataset.data_files,
                        streaming=False,
                        split=None,
                    )
                elif local_path.is_file():
                    ds_type = get_ds_type(config_dataset)

                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                    )
                else:
                    raise ValueError(
                        "unhandled dataset load: local path exists, but is neither a directory or a file"
                    )
            elif ds_from_hub:
                ds = load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=False,
                    data_files=config_dataset.data_files,
                    token=use_auth_token,
                )
            elif ds_from_cloud and remote_file_system:
                if remote_file_system.isdir(config_dataset.path):
                    ds = load_from_disk(
                        config_dataset.path,
                        storage_options=storage_options,
                    )
                elif remote_file_system.isfile(config_dataset.path):
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                        storage_options=storage_options,
                    )
            else:
                # fall back to downloading explicit data_files from a hub dataset repo
                if isinstance(config_dataset.data_files, str):
                    fp = hf_hub_download(
                        repo_id=config_dataset.path,
                        repo_type="dataset",
                        filename=config_dataset.data_files,
                    )
                elif isinstance(config_dataset.data_files, list):
                    fp = [
                        hf_hub_download(
                            repo_id=config_dataset.path,
                            repo_type="dataset",
                            filename=file,
                        )
                        for file in config_dataset.data_files
                    ]
                else:
                    raise ValueError(
                        "data_files must be either a string or list of strings"
                    )
                ds = load_dataset(
                    "json",
                    name=config_dataset.name,
                    data_files=fp,
                    streaming=False,
                    split=None,
                )
            if not ds:
                raise ValueError("unhandled dataset load")

            # support for using a subset of the data.
            # Membership is only tested on DatasetDict: `"train" in Dataset`
            # falls back to iterating (and comparing) every row.
            if config_dataset.shards:
                if isinstance(ds, DatasetDict) and "train" in ds:
                    ds = ds.shuffle(seed=seed)["train"].shard(
                        num_shards=config_dataset.shards, index=0
                    )
                else:
                    ds = ds.shuffle(seed=seed).shard(
                        num_shards=config_dataset.shards, index=0
                    )

            # `type` may be "<base>:<prompt_style>"
            d_base_type = d_prompt_style = None
            d_type = config_dataset.type
            if isinstance(d_type, str):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

            if isinstance(ds, DatasetDict):
                if "train" in ds:
                    ds = ds["train"]
                elif (
                    config_dataset.train_on_split
                    and config_dataset.train_on_split in ds
                ):
                    ds = ds[config_dataset.train_on_split]
                else:
                    raise ValueError(
                        f"no train split found for dataset {config_dataset.path}, you may specify a split with `train_on_split: <split_name>`"
                    )

            dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                config_dataset=config_dataset,
                dataset=ds,
                tokenizer=tokenizer,
                cfg=cfg,
                d_base_type=d_base_type,
                d_prompt_style=d_prompt_style,
            )
            datasets.append(dataset_wrapper)
            prompters.append(dataset_prompter)

        LOG.info("merging datasets")
        dataset = concatenate_datasets(datasets)

        if len(datasets) > 1:
            LOG.info("shuffle merged datasets")
            dataset = dataset.shuffle(seed=seed)

        dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)

        if cfg.local_rank == 0:
            LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
            dataset.save_to_disk(prepared_ds_path)
            if cfg.push_dataset_to_hub:
                LOG.info(
                    f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                )
                dataset.push_to_hub(
                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                )

    return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
    """Return the ``load_dataset`` builder name for a single-file dataset.

    An explicit ``config_dataset.ds_type`` wins; otherwise the type is guessed
    from a substring of ``config_dataset.path`` (parquet, arrow, csv, txt ->
    text), defaulting to ``"json"``.
    """
    ds_type = "json"
    if config_dataset.ds_type:
        ds_type = config_dataset.ds_type
    elif ".parquet" in config_dataset.path:
        ds_type = "parquet"
    elif ".arrow" in config_dataset.path:
        ds_type = "arrow"
    elif ".csv" in config_dataset.path:
        ds_type = "csv"
    elif ".txt" in config_dataset.path:
        ds_type = "text"
    return ds_type


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
    """Load the tokenized dataset, optionally shard it, and split off an eval set.

    Returns:
        ``(train_dataset, eval_dataset, prompters)``; ``eval_dataset`` is
        ``None`` when ``cfg.val_set_size`` is falsy.
    """
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path
    )

    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
        LOG.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )

    if not cfg.val_set_size:
        return dataset, None, prompters

    # ensure we end up with the same fingerprint by doing rank0 first and being
    # able to cache: the split fingerprints are derived deterministically from
    # the source fingerprint, the split size, the split name and the seed.
    seed = cfg.seed or 42
    base_fingerprint = dataset._fingerprint  # pylint: disable=protected-access

    def split_fingerprint(split_name: str) -> str:
        return md5(
            "|".join([base_fingerprint, str(cfg.val_set_size), split_name, str(seed)])
        )

    dataset = dataset.train_test_split(
        test_size=cfg.val_set_size,
        shuffle=False,
        seed=seed,
        train_new_fingerprint=split_fingerprint("train"),
        test_new_fingerprint=split_fingerprint("test"),
    )
    return dataset["train"], dataset["test"], prompters


# Built-in prompt formats, keyed by the part of `type` before ":". Each entry is
# (prompter class, tokenizing strategy class); the prompter receives the part
# after ":" as its prompt style.
_BUILTIN_PROMPT_STRATEGIES = {
    "alpaca": (AlpacaPrompter, AlpacaPromptTokenizingStrategy),
    "explainchoice": (
        MultipleChoiceExplainPrompter,
        AlpacaMultipleChoicePromptTokenizingStrategy,
    ),
    "concisechoice": (
        MultipleChoiceConcisePrompter,
        AlpacaMultipleChoicePromptTokenizingStrategy,
    ),
    "summarizetldr": (SummarizeTLDRPrompter, SummarizeTLDRPromptTokenizingStrategy),
    "jeopardy": (JeopardyPrompter, JeopardyPromptTokenizingStrategy),
    "oasst": (AlpacaPrompter, OpenAssistantPromptTokenizingStrategy),
    "gpteacher": (GPTeacherPrompter, GPTeacherPromptTokenizingStrategy),
    "reflection": (ReflectAlpacaPrompter, AlpacaReflectionPTStrategy),
}


def get_dataset_wrapper(
    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
    """Wrap ``dataset`` with the tokenization strategy selected by its config.

    Resolution order:

    1. Already-tokenized data (has ``input_ids``, ``attention_mask`` and
       ``labels`` columns) is returned unchanged.
    2. A mapping ``type`` is loaded as a ``"user_defined"`` strategy.
    3. ``axolotl.prompt_strategies.load`` is tried with ``type`` as-is.
    4. One of the built-in formats in ``_BUILTIN_PROMPT_STRATEGIES``.

    Returns:
        ``(dataset_wrapper, dataset_prompter)``.

    Raises:
        ValueError: if no strategy matches ``config_dataset.type``.
    """
    ds_kwargs = {
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    if all(
        column in dataset.features
        for column in ("input_ids", "attention_mask", "labels")
    ):
        # dataset is already tokenized, just drop it straight in
        return dataset, UnsupportedPrompter()

    if isinstance(config_dataset.type, DictDefault):
        ds_strategy = load(
            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
        )
        dataset_prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    if ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
        dataset_prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    builtin = _BUILTIN_PROMPT_STRATEGIES.get(d_base_type)
    if builtin is not None:
        prompter_cls, strategy_cls = builtin
        dataset_prompter = prompter_cls(d_prompt_style)
        ds_strategy = strategy_cls(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    suffix = ""
    # `type` can be None or a non-string here; only strings can carry ":load_"
    if isinstance(config_dataset.type, str) and ":load_" in config_dataset.type:
        suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
    LOG.error(
        f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
    )
    raise ValueError(
        f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
    )


def encode_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
    """Tokenize raw text and greedily concatenate it into ``max_tokens`` rows.

    Each example is tokenized with truncation to ``max_tokens - 2`` and then
    gets an EOS token (mask 1) and a PAD token (mask 0) appended. Examples are
    appended to a running buffer; when the next example would overflow
    ``max_tokens``, the buffer is right-padded with ``pad_token_id`` (mask 0)
    and emitted as a row. The trailing partial buffer is padded the same way.

    Returns:
        dict with ``input_ids``, ``labels`` (same values as ``input_ids``) and
        ``attention_mask``, each a list of ``max_tokens``-long int lists.
    """
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
    )

    # Append EOS and PAD tokens to input_ids, and correct attention_mask
    separator_ids = torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id])
    separator_mask = torch.tensor([1, 0])
    input_ids = [
        torch.cat((torch.tensor(seq), separator_ids), dim=0)
        for seq in res["input_ids"]
    ]
    attention_mask = [
        torch.cat((torch.tensor(seq), separator_mask), dim=0)
        for seq in res["attention_mask"]
    ]

    def pad_to_max(tensor: torch.Tensor, value: int) -> torch.Tensor:
        # right-pad to exactly max_tokens (zero-length pad when already full)
        return torch.cat(
            (
                tensor,
                torch.full(
                    (max_tokens - tensor.numel(),),
                    value,
                    dtype=torch.long,
                ),
            ),
            dim=0,
        )

    new_input_ids = []
    new_attention_mask = []
    buffer_input_ids = torch.tensor([], dtype=torch.long)
    buffer_attention_mask = torch.tensor([], dtype=torch.long)
    for ids, mask in zip(input_ids, attention_mask):
        # Every `ids` has at least the 2 separator tokens, so this also covers
        # the "buffer exactly full" case (the pad is then zero-length).
        if buffer_input_ids.numel() + ids.numel() > max_tokens:
            new_input_ids.append(pad_to_max(buffer_input_ids, tokenizer.pad_token_id))
            new_attention_mask.append(pad_to_max(buffer_attention_mask, 0))
            buffer_input_ids = torch.tensor([], dtype=torch.long)
            buffer_attention_mask = torch.tensor([], dtype=torch.long)
        buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
        buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

    if buffer_input_ids.numel() > 0:  # for any leftover tokens
        if buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
            buffer_input_ids = pad_to_max(buffer_input_ids, tokenizer.pad_token_id)
            buffer_attention_mask = pad_to_max(buffer_attention_mask, 0)
        new_input_ids.append(buffer_input_ids)
        new_attention_mask.append(buffer_attention_mask)

    ret = {
        "input_ids": [seq.tolist() for seq in new_input_ids],
        "labels": [seq.tolist() for seq in new_input_ids],
        "attention_mask": [seq.tolist() for seq in new_attention_mask],
    }

    LOG.debug(len(ret["input_ids"]))
    return ret


def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
    """Stream the ``train`` split of ``path`` and lazily tokenize/pack it.

    With ``cfg.sample_packing`` rows are packed by ``encode_packed_pretraining``
    and, as a side effect, ``cfg.micro_batch_size`` is overwritten with ``1``.
    Otherwise ``encode_pretraining`` is used. Only the ``text`` column is fed
    to the encoder; all original columns are removed from the output.
    """
    if cfg.sample_packing:
        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
        )
        encode = functools.partial(
            encode_packed_pretraining,
            tokenizer,
            collate_fn,
            max_seq_length=max_tokens,
            batch_size=cfg.micro_batch_size,
        )
        # set this to 1 so downstream data_loader doesn't try to increase the batch again
        cfg.micro_batch_size = 1
    else:
        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

    dataset = load_dataset(path, streaming=True, split="train", name=name)
    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)

    # A streaming dataset may not know its schema up front (`features` is None);
    # in that case take the column names from the first row instead.
    if dataset.features is not None:
        remove_columns = list(dataset.features.keys())
    else:
        first_row = next(iter(dataset), None)
        remove_columns = list(first_row.keys()) if first_row else []

    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=10_000,
        input_columns="text",
        # remove all the existing columns after mapping since they end up having
        # a different length than the encoded/tokenized column
        remove_columns=remove_columns,
    )
    return dataset


def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase,
    collate_fn,
    examples: List[str],
    max_seq_length: int = 2048,
    batch_size: int = 4,
) -> Dict[str, List]:
    """Tokenize raw text and pack it into multipack batches.

    Long examples are split into overlapping windows (``stride=256``) of at
    most ``max_seq_length - 1`` tokens, each followed by EOS. The windows are
    grouped by ``MultipackBatchSampler`` (``batch_size * max_seq_length`` token
    budget, incomplete batches dropped) and each group is collated with
    ``collate_fn``.

    Returns:
        mapping of feature name -> list of collated tensors (one per sampled
        batch, leading dim squeezed), excluding the ``length`` feature.
    """
    # pylint: disable=duplicate-code
    # tokenize all the examples
    # rows get split with stride (overlap)
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_seq_length - 1,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )

    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
    attention_mask = [seq + [1] for seq in res["attention_mask"]]

    tokenized_examples = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    train_dataset = Dataset.from_dict(tokenized_examples)
    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset, max_seq_length
    )

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=batch_size,
        drop_last=True,
        batch_max_len=batch_size * max_seq_length,
        lengths=get_dataset_lengths(train_dataset),
    )

    chunked_data = defaultdict(list)

    for data in sampler:
        features = train_dataset[data]
        features["labels"] = features["input_ids"].copy()
        collated_features = collate_fn(features)

        for feature in features:
            if feature == "length":
                continue
            chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data