|
|
""" |
|
|
Module: tokenization_args.py |
|
|
|
|
|
This module defines the `TokenizationArgs` dataclass, which encapsulates all the configurable parameters |
|
|
required for the tokenization process in the TEDDY project. These parameters control how gene expression |
|
|
data and biological annotations are tokenized for training. |
|
|
|
|
|
Main Features: |
|
|
- Provides a structured way to define and manage tokenization arguments. |
|
|
- Supports configuration for gene selection, sequence truncation, and annotation inclusion. |
|
|
- Includes options for handling PerturbSeq-specific flags and preprocessing steps. |
|
|
- Allows for flexible mapping of biological annotations (e.g., disease, tissue, cell type, sex). |
|
|
- Enables reproducibility through random seed control for gene selection. |
|
|
|
|
|
Dependencies: |
|
|
- `dataclasses`: For defining the `TokenizationArgs` dataclass. |
|
|
|
|
|
Usage: |
|
|
1. Import the `TokenizationArgs` class: |
|
|
```python |
|
|
from teddy.tokenizer.tokenization_args import TokenizationArgs
|
|
``` |
|
|
2. Define tokenization arguments for a specific tokenization task: |
|
|
```python |
|
|
tokenization_args = TokenizationArgs( |
|
|
tokenizer_name_or_path="path/to/tokenizer", |
|
|
... |
|
|
) |
|
|
``` |
|
|
3. Pass the `tokenization_args` object to the tokenization function: |
|
|
```python |
|
|
tokenized_data = tokenize(data, tokenization_args) |
|
|
``` |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field
from typing import Optional
|
|
|
|
|
|
|
|
@dataclass
class TokenizationArgs:
    """Configuration for tokenizing gene expression data.

    Groups every tunable parameter of the tokenization pipeline: gene
    selection, sequence truncation/padding, PerturbSeq-specific handling,
    biological-annotation inclusion and masking, label processing, and the
    load/save directories. Each ``field`` carries a ``help`` string in its
    metadata so the class can be consumed by metadata-aware argument parsers.
    """

    # --- Tokenizer and gene selection -----------------------------------
    tokenizer_name_or_path: str = field(metadata={"help": "Path to tokenizer used."})
    gene_id_column: str = field(default="index", metadata={"help": "Field to use while accessing gene_ids for values."})
    random_genes: bool = field(
        default=False, metadata={"help": "whether we want random genes (True) selection or top expressed ones (False)"}
    )
    # Fixed: help text previously duplicated the tokenizer_name_or_path description.
    include_zero_genes: bool = field(
        default=False, metadata={"help": "Whether to include zero-expression genes in the tokenized sequence."}
    )

    # --- Special tokens --------------------------------------------------
    add_cls: bool = field(default=False, metadata={"help": "Whether to add cls token to the start of the sequence."})
    # None until a concrete tokenizer vocabulary assigns one.
    cls_token_id: Optional[int] = field(default=None, metadata={"help": "Token id for cls token."})

    # --- PerturbSeq-specific flags ---------------------------------------
    perturbseq: bool = field(
        default=False,
        metadata={"help": "[PerturbSeq specific flag] Whether to add perturbation token during tokenization."},
    )
    tokenize_perturbseq_for_train: bool = field(
        default=True,
        metadata={
            "help": "[PerturbSeq specific flag] Whether to tokenize labels to prepare data for training or to simply prepare tokenized perturbation flags for inference."
        },
    )
    add_tokens: tuple = field(
        default=(),
        metadata={
            "help": "Enter a tuple of string values for tokens. Will be pre-pended to the gene id sequence. Can be used instead of add_cls"
        },
    )

    # --- Labels and annotations ------------------------------------------
    add_disease_annotation: bool = field(
        default=False, metadata={"help": "Whether to add a disease annotation token to the sequence."}
    )
    label_column: Optional[str] = field(
        default=None, metadata={"help": "Which column to use as a label for a classification task."}
    )

    # --- Sequence length, sharding, and binning ---------------------------
    max_shard_samples: int = field(default=500, metadata={"help": "Number of samples included in sharding."})
    max_seq_len: int = field(default=3001, metadata={"help": "Max seq length used for data processing"})
    pad_length: int = field(
        default=3001,
        metadata={"help": "Pad sequence to x length so that all arrays in all batches are same length"},
    )
    truncation_method: str = field(
        default="max",
        metadata={
            "help": "Indicate here how to restrict the number of genes to obtain max_seq_len from the full set of expression values. Options: max, random"
        },
    )
    bins: Optional[int] = field(default=None, metadata={"help": "Number of bins used when required for data processing"})

    # --- Value rescaling ---------------------------------------------------
    rescale_labels: bool = field(default=False, metadata={"help": "If true, labels are binned or continuously ranked"})
    continuous_rank: bool = field(
        default=False, metadata={"help": "If true, gene values are overwritten with linspace[-1, 1] by rank."}
    )

    # --- Biological annotations -------------------------------------------
    bio_annotations: bool = field(
        default=False, metadata={"help": "If true, include disease, tissue type, cell type, sex"}
    )
    bio_annotation_masking_prob: float = field(
        default=0.15, metadata={"help": "Mask annotation tokens with this probability"}
    )
    disease_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from disease names to standard disease categories"}
    )
    tissue_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from tissue names to standard tissue categories"}
    )
    cell_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from cell type names to standard cell types"}
    )
    sex_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from sex names to standard sex categories"}
    )

    # --- I/O directories ---------------------------------------------------
    load_dir: str = field(default="", metadata={"help": "Directory where h5ad data is loaded from."})
    save_dir: str = field(
        default="",
        metadata={
            "help": "Directory where tokenization function will save data. tokenize() saves tokenized in data_path.replace(load_dir, save_dir)"
        },
    )

    # --- Reproducibility ----------------------------------------------------
    gene_seed: int = field(default=42, metadata={"help": "Random seed that controls randomness of gene selection"})
|
|
|