""" Utilities for reading and writing data files. """ import multiprocessing as mp import os from pathlib import PosixPath from typing import Callable, Dict, List, Optional, Tuple, Union from datasets import load_dataset from torch.utils.data import Dataset from transformers import ( DataCollatorForLanguageModeling, PreTrainedTokenizer, default_data_collator, ) from . import config # To avoid huggingface warning os.environ["TOKENIZERS_PARALLELISM"] = "false" UBUNTU_ROOT = str(config.root) def load_datasets( tokenizer: PreTrainedTokenizer, train_data: Union[str, PosixPath], eval_data: Optional[Union[str, PosixPath]] = None, test_data: Union[str, PosixPath] = None, file_type: str = "csv", delimiter: str = "\t", seq_key: str = "sequence", shuffle: bool = True, filter_empty: bool = False, n_workers: int = mp.cpu_count(), **kwargs, ) -> Dataset: """Load and cache data using Huggingface datasets library Args: tokenizer (PreTrainedTokenizer): tokenizer to apply to the sequences train_data (Union[str, PosixPath]): location of training data eval_data (Union[str, PosixPath], optional): location of evaluation data. Defaults to None. test_data (Union[str, PosixPath], optional): location of test data. Defaults to None. file_type (str, optional): type of file. Possible values are 'text' and 'csv'. Defaults to 'csv'. delimiter (str, optional): Defaults to '\t'. seq_key (str, optional): Column name of sequence data Can be 'sequence', 'seq', or 'text'. Defaults to 'sequence'. shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True. filter_empty (bool, optional): Whether to filter out empty sequences. Defaults to False. NOTE: This completes an additional iteration, which can be time-consuming. Only enable if you have reason to believe that preprocessing steps will result in empty sequences. transformation (str, optional): type of transformation to apply. Options are 'log', 'boxcox'. Defaults to None. log_offset (Union[float, int]): value to offset gene expression values by before log transforming. Defaults to 1. preprocessor (BaseEstimator): preprocessor Yeoh-Johnson transformation. tissue_subset (Union[str, int, list], optional): tissues to subset labels to. Defaults to None. nshards (int, optional): Number of shards to divide data into, only keeping the first. Defaults to None. threshold (float, optional): filter out rows where all labels are below `threshold`. OR if `discretize` is True, see `discretize`. Defaults to None. discretize (bool, optional): set gene expression values below `threshold` to 0, above `threshold` to 1. kmer (int, optional): whether to run the kmer flip experiment and if so, how large kmers to flip. Defaults to None. n_workers (int, optional): number of processes to use for preprocessing. Defaults to `mp.cpu_count()` (number of available CPUs). position_buckets (Tuple[int], optional): the different buckets for the bucketed positional importance experiment Returns: Dataset """ data_files = {"train": str(train_data)} if eval_data: data_files["eval"] = str(eval_data) if test_data: data_files["test"] = str(test_data) if file_type == "csv": kwargs.update({"delimiter": delimiter}) datasets = load_dataset(file_type, data_files=data_files, **kwargs) # Tokenizing preprocess_fn = make_preprocess_function(tokenizer, seq_key=seq_key) # print("Tokenizing...") datasets = datasets.map(preprocess_fn, batched=True, num_proc=n_workers) if filter_empty: datasets = datasets.filter(filter_empty_sequence) if shuffle: seed = config.settings["random_seed"] datasets = datasets.shuffle(seeds={"train": seed, "eval": seed, "test": seed}) return datasets def make_preprocess_function(tokenizer, seq_key: str = "sequence") -> callable: """Make a preprocessing function that selects the appropriate column and tokenizes it. Args: tokenizer (PreTrainedTokenizer): tokenizer to apply to each sequence seq_key (str, optional): column name of the text data. Defaults to 'sequence'. Returns: callable: preprocessing function """ def preprocess_function(examples): if seq_key: seqs = examples[seq_key] else: seqs = examples return tokenizer( seqs, max_length=tokenizer.model_max_length, truncation=True, padding="max_length", ) return preprocess_function def filter_empty_sequence(example: dict) -> bool: """Filter out empty sequences.""" # sum(example['attention_mask']) gives the number of tokens, including SOS and EOS return sum(example["attention_mask"]) > 2 def load_data_collator(model_type: str, tokenizer=None, mlm_prob=None): if model_type == "language-model": assert ( tokenizer is not None ), "tokenizer must not be None if model is type language-model" assert ( mlm_prob is not None ), "mlm_prob must not be None if model is type language-model" return DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob ) else: return default_data_collator