genetransformer
#324
by
sofiaztj
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .gitattributes +2 -2
- MANIFEST.in +3 -4
- README.md +8 -27
- config.json +8 -9
- docs/source/about.rst +5 -9
- docs/source/api.rst +0 -8
- docs/source/geneformer.mtl_classifier.rst +0 -11
- docs/source/geneformer.tokenizer.rst +1 -2
- docs/source/index.rst +1 -1
- examples/multitask_cell_classification.ipynb +0 -420
- examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +1 -3
- examples/tokenizing_scRNAseq_data.ipynb +1 -5
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
- {gf-12L-30M-i2048 → geneformer-12L-30M}/config.json +0 -0
- {gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin +0 -0
- {gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin +0 -0
- geneformer/__init__.py +1 -14
- geneformer/classifier.py +56 -189
- geneformer/classifier_utils.py +35 -258
- geneformer/collator_for_classification.py +74 -139
- geneformer/emb_extractor.py +50 -101
- geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
- geneformer/evaluation_utils.py +5 -5
- geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl +0 -3
- geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl +0 -3
- geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl +0 -3
- geneformer/gene_median_dictionary.pkl +0 -0
- geneformer/gene_median_dictionary_gc95M.pkl +0 -3
- geneformer/{gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl → gene_name_id_dict.pkl} +0 -0
- geneformer/gene_name_id_dict_gc95M.pkl +0 -3
- geneformer/in_silico_perturber.py +136 -776
- geneformer/in_silico_perturber_stats.py +26 -76
- geneformer/mtl/__init__.py +0 -1
- geneformer/mtl/collators.py +0 -76
- geneformer/mtl/data.py +0 -150
- geneformer/mtl/eval_utils.py +0 -88
- geneformer/mtl/imports.py +0 -43
- geneformer/mtl/model.py +0 -121
- geneformer/mtl/optuna_utils.py +0 -27
- geneformer/mtl/train.py +0 -380
- geneformer/mtl/train_utils.py +0 -161
- geneformer/mtl/utils.py +0 -129
.gitattributes
CHANGED
@@ -14,11 +14,10 @@
|
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
24 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
@@ -26,4 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
26 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
27 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
28 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
29 |
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
17 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
geneformer/gene_name_id_dict.pkl filter=lfs diff=lfs merge=lfs -text
|
29 |
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
MANIFEST.in
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
include geneformer/
|
2 |
-
include geneformer/
|
3 |
-
include geneformer/
|
4 |
-
include geneformer/token_dictionary_gc95M.pkl
|
|
|
1 |
+
include geneformer/gene_median_dictionary.pkl
|
2 |
+
include geneformer/token_dictionary.pkl
|
3 |
+
include geneformer/gene_name_id_dict.pkl
|
|
README.md
CHANGED
@@ -3,38 +3,23 @@ datasets: ctheodoris/Genecorpus-30M
|
|
3 |
license: apache-2.0
|
4 |
---
|
5 |
# Geneformer
|
6 |
-
Geneformer is a
|
7 |
|
8 |
-
- See [our manuscript](https://rdcu.be/ddrx0) for details
|
9 |
-
- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
10 |
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
11 |
|
12 |
# Model Description
|
13 |
-
Geneformer is a
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
18 |
|
19 |
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
20 |
|
21 |
-
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an
|
22 |
-
|
23 |
-
The repository includes the following pretrained models:
|
24 |
-
|
25 |
-
L=layers\
|
26 |
-
M=millions of cells used for pretraining\
|
27 |
-
i=input size\
|
28 |
-
(pretraining date)
|
29 |
|
30 |
-
|
31 |
-
- GF-12L-30M-i2048 (June 2021)
|
32 |
-
- GF-12L-95M-i4096 (April 2024)
|
33 |
-
- GF-20L-95M-i4096 (April 2024)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
|
38 |
|
39 |
# Application
|
40 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
@@ -64,7 +49,7 @@ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) inc
|
|
64 |
- in silico perturbation to determine transcription factor cooperativity
|
65 |
|
66 |
# Installation
|
67 |
-
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install
|
68 |
|
69 |
```bash
|
70 |
# Make sure you have git-lfs installed (https://git-lfs.com)
|
@@ -85,7 +70,3 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
|
|
85 |
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
86 |
|
87 |
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
88 |
-
|
89 |
-
# Citations
|
90 |
-
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
|
91 |
-
- H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
|
|
|
3 |
license: apache-2.0
|
4 |
---
|
5 |
# Geneformer
|
6 |
+
Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
7 |
|
8 |
+
- See [our manuscript](https://rdcu.be/ddrx0) for details.
|
|
|
9 |
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
10 |
|
11 |
# Model Description
|
12 |
+
Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
13 |
|
14 |
+
The rank value encoding of each single cell��s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
|
|
|
|
15 |
|
16 |
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
17 |
|
18 |
+
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on ~30 million human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
|
|
|
|
|
|
|
21 |
|
22 |
+
Both the 6 and 12 layer Geneformer models were pretrained in June 2021.
|
|
|
|
|
23 |
|
24 |
# Application
|
25 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
|
|
49 |
- in silico perturbation to determine transcription factor cooperativity
|
50 |
|
51 |
# Installation
|
52 |
+
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install:
|
53 |
|
54 |
```bash
|
55 |
# Make sure you have git-lfs installed (https://git-lfs.com)
|
|
|
70 |
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
71 |
|
72 |
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
|
|
|
|
|
|
|
config.json
CHANGED
@@ -3,22 +3,21 @@
|
|
3 |
"BertForMaskedLM"
|
4 |
],
|
5 |
"attention_probs_dropout_prob": 0.02,
|
6 |
-
"
|
7 |
"hidden_act": "relu",
|
8 |
"hidden_dropout_prob": 0.02,
|
9 |
-
"hidden_size":
|
10 |
"initializer_range": 0.02,
|
11 |
-
"intermediate_size":
|
12 |
"layer_norm_eps": 1e-12,
|
13 |
-
"max_position_embeddings":
|
14 |
"model_type": "bert",
|
15 |
-
"num_attention_heads":
|
16 |
-
"num_hidden_layers":
|
17 |
"pad_token_id": 0,
|
18 |
"position_embedding_type": "absolute",
|
19 |
-
"
|
20 |
-
"transformers_version": "4.37.1",
|
21 |
"type_vocab_size": 2,
|
22 |
"use_cache": true,
|
23 |
-
"vocab_size":
|
24 |
}
|
|
|
3 |
"BertForMaskedLM"
|
4 |
],
|
5 |
"attention_probs_dropout_prob": 0.02,
|
6 |
+
"gradient_checkpointing": false,
|
7 |
"hidden_act": "relu",
|
8 |
"hidden_dropout_prob": 0.02,
|
9 |
+
"hidden_size": 256,
|
10 |
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 512,
|
12 |
"layer_norm_eps": 1e-12,
|
13 |
+
"max_position_embeddings": 2048,
|
14 |
"model_type": "bert",
|
15 |
+
"num_attention_heads": 4,
|
16 |
+
"num_hidden_layers": 6,
|
17 |
"pad_token_id": 0,
|
18 |
"position_embedding_type": "absolute",
|
19 |
+
"transformers_version": "4.6.0",
|
|
|
20 |
"type_vocab_size": 2,
|
21 |
"use_cache": true,
|
22 |
+
"vocab_size": 25426
|
23 |
}
|
docs/source/about.rst
CHANGED
@@ -4,13 +4,11 @@ About
|
|
4 |
Model Description
|
5 |
-----------------
|
6 |
|
7 |
-
**Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on
|
8 |
|
9 |
-
In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the
|
10 |
|
11 |
-
Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/
|
12 |
-
|
13 |
-
Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
14 |
|
15 |
Application
|
16 |
-----------
|
@@ -41,9 +39,7 @@ Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ i
|
|
41 |
| - in silico perturbation to determine transcription factor targets
|
42 |
| - in silico perturbation to determine transcription factor cooperativity
|
43 |
|
44 |
-
|
45 |
-
|
46 |
|
47 |
| C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
|
48 |
-
|
49 |
-
| H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
|
|
|
4 |
Model Description
|
5 |
-----------------
|
6 |
|
7 |
+
**Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of ~30 million single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on ~30 million human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
8 |
|
9 |
+
In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
|
10 |
|
11 |
+
Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/pytorch_model.bin>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer-12L-30M/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
|
|
|
|
|
12 |
|
13 |
Application
|
14 |
-----------
|
|
|
39 |
| - in silico perturbation to determine transcription factor targets
|
40 |
| - in silico perturbation to determine transcription factor cooperativity
|
41 |
|
42 |
+
Citation
|
43 |
+
--------
|
44 |
|
45 |
| C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
|
|
|
|
docs/source/api.rst
CHANGED
@@ -17,14 +17,6 @@ Classifier
|
|
17 |
|
18 |
geneformer.classifier
|
19 |
|
20 |
-
Multitask Classifier
|
21 |
-
--------------------
|
22 |
-
|
23 |
-
.. toctree::
|
24 |
-
:maxdepth: 1
|
25 |
-
|
26 |
-
geneformer.mtl_classifier
|
27 |
-
|
28 |
Embedding Extractor
|
29 |
-------------------
|
30 |
|
|
|
17 |
|
18 |
geneformer.classifier
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
Embedding Extractor
|
21 |
-------------------
|
22 |
|
docs/source/geneformer.mtl_classifier.rst
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
geneformer.mtl\_classifier
|
2 |
-
==========================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.mtl_classifier
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
valid_option_dict,
|
10 |
-
validate_options,
|
11 |
-
validate_additional_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.tokenizer.rst
CHANGED
@@ -11,5 +11,4 @@ geneformer.tokenizer
|
|
11 |
tokenize_files,
|
12 |
tokenize_loom,
|
13 |
rank_genes,
|
14 |
-
tokenize_cell
|
15 |
-
sum_ensembl_ids
|
|
|
11 |
tokenize_files,
|
12 |
tokenize_loom,
|
13 |
rank_genes,
|
14 |
+
tokenize_cell
|
|
docs/source/index.rst
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
Geneformer
|
2 |
==========
|
3 |
|
4 |
-
Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
|
5 |
|
6 |
See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
|
7 |
|
|
|
1 |
Geneformer
|
2 |
==========
|
3 |
|
4 |
+
Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in network biology.
|
5 |
|
6 |
See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
|
7 |
|
examples/multitask_cell_classification.ipynb
DELETED
@@ -1,420 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "866f100c-e11a-4e7b-a37c-831775d845a7",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# Geneformer Multi-Task Cell Classifier Tutorial\n",
|
9 |
-
"\n",
|
10 |
-
"This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "markdown",
|
15 |
-
"id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
|
16 |
-
"metadata": {},
|
17 |
-
"source": [
|
18 |
-
"## 1. Installation and Imports\n",
|
19 |
-
"\n",
|
20 |
-
"First import the necessary modules."
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"cell_type": "code",
|
25 |
-
"execution_count": 3,
|
26 |
-
"id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
|
27 |
-
"metadata": {},
|
28 |
-
"outputs": [],
|
29 |
-
"source": [
|
30 |
-
"from geneformer import MTLClassifier"
|
31 |
-
]
|
32 |
-
},
|
33 |
-
{
|
34 |
-
"cell_type": "markdown",
|
35 |
-
"id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
|
36 |
-
"metadata": {},
|
37 |
-
"source": [
|
38 |
-
"## 2. Set up Paths and Parameters\n",
|
39 |
-
"\n",
|
40 |
-
"Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": null,
|
46 |
-
"id": "04a04197-8e45-47f8-a86f-202209ea10ae",
|
47 |
-
"metadata": {},
|
48 |
-
"outputs": [],
|
49 |
-
"source": [
|
50 |
-
"# Define paths\n",
|
51 |
-
"pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
|
52 |
-
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
53 |
-
"train_path = \"/path/to/train/data.dataset\"\n",
|
54 |
-
"val_path = \"/path/to/val/data.dataset\"\n",
|
55 |
-
"test_path = \"/path/to/test/data.dataset\"\n",
|
56 |
-
"results_dir = \"/path/to/results/directory\"\n",
|
57 |
-
"model_save_path = \"/path/to/model/save/path\"\n",
|
58 |
-
"tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
|
59 |
-
"\n",
|
60 |
-
"# Define tasks and hyperparameters\n",
|
61 |
-
"# task_columns should be a list of column names from your dataset\n",
|
62 |
-
"# Each column represents a specific classification task (e.g. cell type, disease state)\n",
|
63 |
-
"task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
|
64 |
-
"\n",
|
65 |
-
"hyperparameters = {\n",
|
66 |
-
" \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
|
67 |
-
" \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
|
68 |
-
" \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
|
69 |
-
" \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
|
70 |
-
" \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
|
71 |
-
" \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
|
72 |
-
"}"
|
73 |
-
]
|
74 |
-
},
|
75 |
-
{
|
76 |
-
"cell_type": "markdown",
|
77 |
-
"id": "31857690-a739-435a-aefd-f171fafc1b78",
|
78 |
-
"metadata": {},
|
79 |
-
"source": [
|
80 |
-
"In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
|
81 |
-
"1. Identifying the cell type\n",
|
82 |
-
"2. Determining the disease state\n",
|
83 |
-
"3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
|
84 |
-
"\n",
|
85 |
-
"These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
|
86 |
-
"\n",
|
87 |
-
"For example, your dataset might look something like this:\n",
|
88 |
-
"\n",
|
89 |
-
" | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
|
90 |
-
" |----------------|-----------|-----|-----------|---------------|\n",
|
91 |
-
" | cell1 | ... | ... | neuron | healthy |\n",
|
92 |
-
" | cell2 | ... | ... | astrocyte | diseased |\n",
|
93 |
-
" | ... | ... | ... | ... | ... |\n",
|
94 |
-
"The model will learn to predict classes within 'cell_type' and 'disease_state' "
|
95 |
-
]
|
96 |
-
},
|
97 |
-
{
|
98 |
-
"cell_type": "markdown",
|
99 |
-
"id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
|
100 |
-
"metadata": {},
|
101 |
-
"source": [
|
102 |
-
"## 3. Initialize the MTLClassifier\n",
|
103 |
-
"\n",
|
104 |
-
"Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
|
105 |
-
]
|
106 |
-
},
|
107 |
-
{
|
108 |
-
"cell_type": "code",
|
109 |
-
"execution_count": null,
|
110 |
-
"id": "e27caac9-670c-409d-9313-50201c665cb9",
|
111 |
-
"metadata": {},
|
112 |
-
"outputs": [],
|
113 |
-
"source": [
|
114 |
-
"mc = MTLClassifier(\n",
|
115 |
-
" task_columns=task_columns, # Our defined classification tasks\n",
|
116 |
-
" study_name=\"MTLClassifier_example\",\n",
|
117 |
-
" pretrained_path=pretrained_path,\n",
|
118 |
-
" train_path=train_path,\n",
|
119 |
-
" val_path=val_path,\n",
|
120 |
-
" test_path=test_path,\n",
|
121 |
-
" model_save_path=model_save_path,\n",
|
122 |
-
" results_dir=results_dir,\n",
|
123 |
-
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
124 |
-
" hyperparameters=hyperparameters,\n",
|
125 |
-
" n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
|
126 |
-
" epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
|
127 |
-
" batch_size=8, # Adjust based on available GPU memory\n",
|
128 |
-
" seed=42\n",
|
129 |
-
")"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "markdown",
|
134 |
-
"id": "0d729444-e3ad-4584-9659-0c464ac97462",
|
135 |
-
"metadata": {},
|
136 |
-
"source": [
|
137 |
-
"## 4. Run Hyperparameter Optimization\n",
|
138 |
-
"\n",
|
139 |
-
"Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
|
140 |
-
]
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "code",
|
144 |
-
"execution_count": null,
|
145 |
-
"id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
|
146 |
-
"metadata": {},
|
147 |
-
"outputs": [],
|
148 |
-
"source": [
|
149 |
-
"mc.run_optuna_study()"
|
150 |
-
]
|
151 |
-
},
|
152 |
-
{
|
153 |
-
"cell_type": "markdown",
|
154 |
-
"id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
|
155 |
-
"metadata": {},
|
156 |
-
"source": [
|
157 |
-
"## 5. Evaluate the Model on Test Data\n",
|
158 |
-
"\n",
|
159 |
-
"After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
|
160 |
-
]
|
161 |
-
},
|
162 |
-
{
|
163 |
-
"cell_type": "code",
|
164 |
-
"execution_count": null,
|
165 |
-
"id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
|
166 |
-
"metadata": {},
|
167 |
-
"outputs": [],
|
168 |
-
"source": [
|
169 |
-
"mc.load_and_evaluate_test_model()"
|
170 |
-
]
|
171 |
-
},
|
172 |
-
{
|
173 |
-
"cell_type": "markdown",
|
174 |
-
"id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
|
175 |
-
"metadata": {},
|
176 |
-
"source": [
|
177 |
-
"## 6. (Optional) Manual Hyperparameter Tuning\n",
|
178 |
-
"\n",
|
179 |
-
"If you prefer to set hyperparameters manually, you can use the following approach:"
|
180 |
-
]
|
181 |
-
},
|
182 |
-
{
|
183 |
-
"cell_type": "code",
|
184 |
-
"execution_count": null,
|
185 |
-
"id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
|
186 |
-
"metadata": {},
|
187 |
-
"outputs": [],
|
188 |
-
"source": [
|
189 |
-
"manual_hyperparameters = {\n",
|
190 |
-
" \"learning_rate\": 0.001,\n",
|
191 |
-
" \"warmup_ratio\": 0.01,\n",
|
192 |
-
" \"weight_decay\": 0.1,\n",
|
193 |
-
" \"dropout_rate\": 0.1,\n",
|
194 |
-
" \"lr_scheduler_type\": \"cosine\",\n",
|
195 |
-
" \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
|
196 |
-
" \"max_layers_to_freeze\": 2\n",
|
197 |
-
"}\n",
|
198 |
-
"\n",
|
199 |
-
"mc_manual = MTLClassifier(\n",
|
200 |
-
" task_columns=task_columns,\n",
|
201 |
-
" study_name=\"mtl_manual\",\n",
|
202 |
-
" pretrained_path=pretrained_path,\n",
|
203 |
-
" train_path=train_path,\n",
|
204 |
-
" val_path=val_path,\n",
|
205 |
-
" test_path=test_path,\n",
|
206 |
-
" model_save_path=model_save_path,\n",
|
207 |
-
" results_dir=results_dir,\n",
|
208 |
-
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
209 |
-
" manual_hyperparameters=manual_hyperparameters,\n",
|
210 |
-
" use_manual_hyperparameters=True,\n",
|
211 |
-
" epochs=10,\n",
|
212 |
-
" batch_size=32,\n",
|
213 |
-
" seed=42\n",
|
214 |
-
")\n",
|
215 |
-
"\n",
|
216 |
-
"mc_manual.run_manual_tuning()"
|
217 |
-
]
|
218 |
-
},
|
219 |
-
{
|
220 |
-
"cell_type": "markdown",
|
221 |
-
"id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
|
222 |
-
"metadata": {},
|
223 |
-
"source": [
|
224 |
-
"# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
|
225 |
-
"This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
|
226 |
-
]
|
227 |
-
},
|
228 |
-
{
|
229 |
-
"cell_type": "code",
|
230 |
-
"execution_count": null,
|
231 |
-
"id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
|
232 |
-
"metadata": {},
|
233 |
-
"outputs": [],
|
234 |
-
"source": [
|
235 |
-
"from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
|
236 |
-
]
|
237 |
-
},
|
238 |
-
{
|
239 |
-
"cell_type": "code",
|
240 |
-
"execution_count": null,
|
241 |
-
"id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
|
242 |
-
"metadata": {},
|
243 |
-
"outputs": [],
|
244 |
-
"source": [
|
245 |
-
"# Define paths\n",
|
246 |
-
"model_directory = \"/path/to/model/save/path\"\n",
|
247 |
-
"input_data_file = \"/path/to/input/data.dataset\"\n",
|
248 |
-
"output_directory = \"/path/to/output/directory\"\n",
|
249 |
-
"output_prefix = \"mtl_quantized_perturbation\"\n",
|
250 |
-
"\n",
|
251 |
-
"# Define parameters\n",
|
252 |
-
"perturb_type = \"delete\" # or \"overexpress\"\n",
|
253 |
-
"\n",
|
254 |
-
"# Define cell states to model\n",
|
255 |
-
"cell_states_to_model = {\n",
|
256 |
-
" \"state_key\": \"disease_state\", \n",
|
257 |
-
" \"start_state\": \"disease\", \n",
|
258 |
-
" \"goal_state\": \"control\"\n",
|
259 |
-
"}\n",
|
260 |
-
"\n",
|
261 |
-
"# Define filter data\n",
|
262 |
-
"filter_data_dict = {\n",
|
263 |
-
" \"cell_type\": [\"Fibroblast\"]\n",
|
264 |
-
"}"
|
265 |
-
]
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "markdown",
|
269 |
-
"id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
|
270 |
-
"metadata": {},
|
271 |
-
"source": [
|
272 |
-
"## 3. Extract State Embeddings\n",
|
273 |
-
"\n",
|
274 |
-
"Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
|
275 |
-
]
|
276 |
-
},
|
277 |
-
{
|
278 |
-
"cell_type": "code",
|
279 |
-
"execution_count": null,
|
280 |
-
"id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
|
281 |
-
"metadata": {},
|
282 |
-
"outputs": [],
|
283 |
-
"source": [
|
284 |
-
"# Initialize EmbExtractor\n",
|
285 |
-
"embex = EmbExtractor(\n",
|
286 |
-
" filter_data_dict=filter_data_dict,\n",
|
287 |
-
" max_ncells=1000, # Number of cells to extract embeddings for\n",
|
288 |
-
" emb_layer=0, # Use the second to last layer\n",
|
289 |
-
" emb_mode = \"cls\",\n",
|
290 |
-
" summary_stat=\"exact_mean\",\n",
|
291 |
-
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
292 |
-
" nproc=4\n",
|
293 |
-
")\n",
|
294 |
-
"\n",
|
295 |
-
"# Extract state embeddings\n",
|
296 |
-
"state_embs_dict = embex.get_state_embs(\n",
|
297 |
-
" cell_states_to_model,\n",
|
298 |
-
" model_directory=model_directory,\n",
|
299 |
-
" input_data_file=input_data_file,\n",
|
300 |
-
" output_directory=output_directory,\n",
|
301 |
-
" output_prefix=output_prefix\n",
|
302 |
-
")"
|
303 |
-
]
|
304 |
-
},
|
305 |
-
{
|
306 |
-
"cell_type": "markdown",
|
307 |
-
"id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
|
308 |
-
"metadata": {},
|
309 |
-
"source": [
|
310 |
-
"## 4. Initialize the InSilicoPerturber\n",
|
311 |
-
"\n",
|
312 |
-
"Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
|
313 |
-
]
|
314 |
-
},
|
315 |
-
{
|
316 |
-
"cell_type": "code",
|
317 |
-
"execution_count": null,
|
318 |
-
"id": "09f985a1-91bc-4e8d-8001-a3663531b570",
|
319 |
-
"metadata": {},
|
320 |
-
"outputs": [],
|
321 |
-
"source": [
|
322 |
-
"# Initialize InSilicoPerturber\n",
|
323 |
-
"isp = InSilicoPerturber(\n",
|
324 |
-
" perturb_type=perturb_type,\n",
|
325 |
-
" genes_to_perturb=\"all\", # Perturb all genes\n",
|
326 |
-
" model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
|
327 |
-
" emb_mode=\"cls\", # Use CLS token embedding\n",
|
328 |
-
" cell_states_to_model=cell_states_to_model,\n",
|
329 |
-
" state_embs_dict=state_embs_dict,\n",
|
330 |
-
" max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
|
331 |
-
" emb_layer=0, \n",
|
332 |
-
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
333 |
-
" nproc=1\n",
|
334 |
-
")"
|
335 |
-
]
|
336 |
-
},
|
337 |
-
{
|
338 |
-
"cell_type": "markdown",
|
339 |
-
"id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
|
340 |
-
"metadata": {},
|
341 |
-
"source": [
|
342 |
-
"## 5. Run In Silico Perturbation\n",
|
343 |
-
"\n",
|
344 |
-
"Run the in silico perturbation on the dataset."
|
345 |
-
]
|
346 |
-
},
|
347 |
-
{
|
348 |
-
"cell_type": "code",
|
349 |
-
"execution_count": null,
|
350 |
-
"id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
|
351 |
-
"metadata": {},
|
352 |
-
"outputs": [],
|
353 |
-
"source": [
|
354 |
-
"# Run perturbation and output intermediate files\n",
|
355 |
-
"isp.perturb_data(\n",
|
356 |
-
" model_directory=model_directory,\n",
|
357 |
-
" input_data_file=input_data_file,\n",
|
358 |
-
" output_directory=output_directory,\n",
|
359 |
-
" output_prefix=output_prefix\n",
|
360 |
-
")"
|
361 |
-
]
|
362 |
-
},
|
363 |
-
{
|
364 |
-
"cell_type": "markdown",
|
365 |
-
"id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
|
366 |
-
"metadata": {},
|
367 |
-
"source": [
|
368 |
-
"## 6. Process Results with InSilicoPerturberStats\n",
|
369 |
-
"\n",
|
370 |
-
"After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
|
371 |
-
]
|
372 |
-
},
|
373 |
-
{
|
374 |
-
"cell_type": "code",
|
375 |
-
"execution_count": null,
|
376 |
-
"id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
|
377 |
-
"metadata": {},
|
378 |
-
"outputs": [],
|
379 |
-
"source": [
|
380 |
-
"# Initialize InSilicoPerturberStats\n",
|
381 |
-
"ispstats = InSilicoPerturberStats(\n",
|
382 |
-
" mode=\"goal_state_shift\",\n",
|
383 |
-
" genes_perturbed=\"all\",\n",
|
384 |
-
" combos=0,\n",
|
385 |
-
" anchor_gene=None,\n",
|
386 |
-
" cell_states_to_model=cell_states_to_model\n",
|
387 |
-
")\n",
|
388 |
-
"\n",
|
389 |
-
"# Process stats and output final .csv\n",
|
390 |
-
"ispstats.get_stats(\n",
|
391 |
-
" input_data_file,\n",
|
392 |
-
" None,\n",
|
393 |
-
" output_directory,\n",
|
394 |
-
" output_prefix\n",
|
395 |
-
")"
|
396 |
-
]
|
397 |
-
}
|
398 |
-
],
|
399 |
-
"metadata": {
|
400 |
-
"kernelspec": {
|
401 |
-
"display_name": "Python 3 (ipykernel)",
|
402 |
-
"language": "python",
|
403 |
-
"name": "python3"
|
404 |
-
},
|
405 |
-
"language_info": {
|
406 |
-
"codemirror_mode": {
|
407 |
-
"name": "ipython",
|
408 |
-
"version": 3
|
409 |
-
},
|
410 |
-
"file_extension": ".py",
|
411 |
-
"mimetype": "text/x-python",
|
412 |
-
"name": "python",
|
413 |
-
"nbconvert_exporter": "python",
|
414 |
-
"pygments_lexer": "ipython3",
|
415 |
-
"version": "3.11.5"
|
416 |
-
}
|
417 |
-
},
|
418 |
-
"nbformat": 4,
|
419 |
-
"nbformat_minor": 5
|
420 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py
CHANGED
@@ -138,9 +138,7 @@ training_args = {
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
-
"save_steps": np.floor(
|
142 |
-
num_examples / geneformer_batch_size / 8
|
143 |
-
), # 8 saves per epoch
|
144 |
"logging_steps": 1000,
|
145 |
"output_dir": training_output_dir,
|
146 |
"logging_dir": logging_dir,
|
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
|
|
|
|
|
142 |
"logging_steps": 1000,
|
143 |
"output_dir": training_output_dir,
|
144 |
"logging_dir": logging_dir,
|
examples/tokenizing_scRNAseq_data.ipynb
CHANGED
@@ -25,11 +25,7 @@
|
|
25 |
"\n",
|
26 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
27 |
"\n",
|
28 |
-
"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer
|
29 |
-
"\n",
|
30 |
-
"#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
|
31 |
-
"\n",
|
32 |
-
"#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
|
33 |
]
|
34 |
},
|
35 |
{
|
|
|
25 |
"\n",
|
26 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
27 |
"\n",
|
28 |
+
"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
|
|
|
|
|
|
|
|
|
29 |
]
|
30 |
},
|
31 |
{
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin
RENAMED
File without changes
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"architectures": [
|
3 |
-
"BertForMaskedLM"
|
4 |
-
],
|
5 |
-
"attention_probs_dropout_prob": 0.02,
|
6 |
-
"classifier_dropout": null,
|
7 |
-
"hidden_act": "relu",
|
8 |
-
"hidden_dropout_prob": 0.02,
|
9 |
-
"hidden_size": 512,
|
10 |
-
"initializer_range": 0.02,
|
11 |
-
"intermediate_size": 1024,
|
12 |
-
"layer_norm_eps": 1e-12,
|
13 |
-
"max_position_embeddings": 4096,
|
14 |
-
"model_type": "bert",
|
15 |
-
"num_attention_heads": 8,
|
16 |
-
"num_hidden_layers": 12,
|
17 |
-
"pad_token_id": 0,
|
18 |
-
"position_embedding_type": "absolute",
|
19 |
-
"torch_dtype": "float32",
|
20 |
-
"transformers_version": "4.37.2",
|
21 |
-
"type_vocab_size": 2,
|
22 |
-
"use_cache": true,
|
23 |
-
"vocab_size": 20275
|
24 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
|
3 |
-
size 152363342
|
|
|
|
|
|
|
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/config.json
RENAMED
File without changes
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin
RENAMED
File without changes
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin
RENAMED
File without changes
|
geneformer/__init__.py
CHANGED
@@ -1,14 +1,4 @@
|
|
1 |
# ruff: noqa: F401
|
2 |
-
import warnings
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
6 |
-
|
7 |
-
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
8 |
-
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
|
9 |
-
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
|
10 |
-
ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
|
11 |
-
|
12 |
from . import (
|
13 |
collator_for_classification,
|
14 |
emb_extractor,
|
@@ -21,7 +11,7 @@ from .collator_for_classification import (
|
|
21 |
DataCollatorForCellClassification,
|
22 |
DataCollatorForGeneClassification,
|
23 |
)
|
24 |
-
from .emb_extractor import EmbExtractor
|
25 |
from .in_silico_perturber import InSilicoPerturber
|
26 |
from .in_silico_perturber_stats import InSilicoPerturberStats
|
27 |
from .pretrainer import GeneformerPretrainer
|
@@ -29,6 +19,3 @@ from .tokenizer import TranscriptomeTokenizer
|
|
29 |
|
30 |
from . import classifier # noqa # isort:skip
|
31 |
from .classifier import Classifier # noqa # isort:skip
|
32 |
-
|
33 |
-
from . import mtl_classifier # noqa # isort:skip
|
34 |
-
from .mtl_classifier import MTLClassifier # noqa # isort:skip
|
|
|
1 |
# ruff: noqa: F401
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from . import (
|
3 |
collator_for_classification,
|
4 |
emb_extractor,
|
|
|
11 |
DataCollatorForCellClassification,
|
12 |
DataCollatorForGeneClassification,
|
13 |
)
|
14 |
+
from .emb_extractor import EmbExtractor
|
15 |
from .in_silico_perturber import InSilicoPerturber
|
16 |
from .in_silico_perturber_stats import InSilicoPerturberStats
|
17 |
from .pretrainer import GeneformerPretrainer
|
|
|
19 |
|
20 |
from . import classifier # noqa # isort:skip
|
21 |
from .classifier import Classifier # noqa # isort:skip
|
|
|
|
|
|
geneformer/classifier.py
CHANGED
@@ -53,18 +53,16 @@ from pathlib import Path
|
|
53 |
import numpy as np
|
54 |
import pandas as pd
|
55 |
import seaborn as sns
|
|
|
56 |
from tqdm.auto import tqdm, trange
|
57 |
from transformers import Trainer
|
58 |
from transformers.training_args import TrainingArguments
|
59 |
|
60 |
-
from . import
|
61 |
-
TOKEN_DICTIONARY_FILE,
|
62 |
-
DataCollatorForCellClassification,
|
63 |
-
DataCollatorForGeneClassification,
|
64 |
-
)
|
65 |
from . import classifier_utils as cu
|
66 |
from . import evaluation_utils as eu
|
67 |
from . import perturber_utils as pu
|
|
|
68 |
|
69 |
sns.set()
|
70 |
|
@@ -75,7 +73,6 @@ logger = logging.getLogger(__name__)
|
|
75 |
class Classifier:
|
76 |
valid_option_dict = {
|
77 |
"classifier": {"cell", "gene"},
|
78 |
-
"quantize": {bool, dict},
|
79 |
"cell_state_dict": {None, dict},
|
80 |
"gene_class_dict": {None, dict},
|
81 |
"filter_data": {None, dict},
|
@@ -89,7 +86,6 @@ class Classifier:
|
|
89 |
"no_eval": {bool},
|
90 |
"stratify_splits_col": {None, str},
|
91 |
"forward_batch_size": {int},
|
92 |
-
"token_dictionary_file": {None, str},
|
93 |
"nproc": {int},
|
94 |
"ngpu": {int},
|
95 |
}
|
@@ -97,7 +93,6 @@ class Classifier:
|
|
97 |
def __init__(
|
98 |
self,
|
99 |
classifier=None,
|
100 |
-
quantize=False,
|
101 |
cell_state_dict=None,
|
102 |
gene_class_dict=None,
|
103 |
filter_data=None,
|
@@ -112,7 +107,6 @@ class Classifier:
|
|
112 |
stratify_splits_col=None,
|
113 |
no_eval=False,
|
114 |
forward_batch_size=100,
|
115 |
-
token_dictionary_file=None,
|
116 |
nproc=4,
|
117 |
ngpu=1,
|
118 |
):
|
@@ -123,13 +117,6 @@ class Classifier:
|
|
123 |
|
124 |
classifier : {"cell", "gene"}
|
125 |
| Whether to fine-tune a cell state or gene classifier.
|
126 |
-
quantize : bool, dict
|
127 |
-
| Whether to fine-tune a quantized model.
|
128 |
-
| If True and no config provided, will use default.
|
129 |
-
| Will use custom config if provided.
|
130 |
-
| Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
|
131 |
-
| For example: {"bnb_config": BitsAndBytesConfig(...),
|
132 |
-
| "peft_config": LoraConfig(...)}
|
133 |
cell_state_dict : None, dict
|
134 |
| Cell states to fine-tune model to distinguish.
|
135 |
| Two-item dictionary with keys: state_key and states
|
@@ -188,9 +175,6 @@ class Classifier:
|
|
188 |
| Otherwise, will perform eval during training.
|
189 |
forward_batch_size : int
|
190 |
| Batch size for forward pass (for evaluation, not training).
|
191 |
-
token_dictionary_file : None, str
|
192 |
-
| Default is to use token dictionary file from Geneformer
|
193 |
-
| Otherwise, will load custom gene token dictionary.
|
194 |
nproc : int
|
195 |
| Number of CPU processes to use.
|
196 |
ngpu : int
|
@@ -199,11 +183,6 @@ class Classifier:
|
|
199 |
"""
|
200 |
|
201 |
self.classifier = classifier
|
202 |
-
if self.classifier == "cell":
|
203 |
-
self.model_type = "CellClassifier"
|
204 |
-
elif self.classifier == "gene":
|
205 |
-
self.model_type = "GeneClassifier"
|
206 |
-
self.quantize = quantize
|
207 |
self.cell_state_dict = cell_state_dict
|
208 |
self.gene_class_dict = gene_class_dict
|
209 |
self.filter_data = filter_data
|
@@ -222,7 +201,6 @@ class Classifier:
|
|
222 |
self.stratify_splits_col = stratify_splits_col
|
223 |
self.no_eval = no_eval
|
224 |
self.forward_batch_size = forward_batch_size
|
225 |
-
self.token_dictionary_file = token_dictionary_file
|
226 |
self.nproc = nproc
|
227 |
self.ngpu = ngpu
|
228 |
|
@@ -244,9 +222,7 @@ class Classifier:
|
|
244 |
] = self.cell_state_dict["states"]
|
245 |
|
246 |
# load token dictionary (Ensembl IDs:token)
|
247 |
-
|
248 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
249 |
-
with open(self.token_dictionary_file, "rb") as f:
|
250 |
self.gene_token_dict = pickle.load(f)
|
251 |
|
252 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
@@ -269,7 +245,7 @@ class Classifier:
|
|
269 |
f"Genes to classify {missing_genes} are not in token dictionary."
|
270 |
)
|
271 |
self.gene_class_dict = {
|
272 |
-
k:
|
273 |
for k, v in self.gene_class_dict.items()
|
274 |
}
|
275 |
empty_classes = []
|
@@ -291,7 +267,7 @@ class Classifier:
|
|
291 |
continue
|
292 |
valid_type = False
|
293 |
for option in valid_options:
|
294 |
-
if (option in [int, float, list, dict, bool
|
295 |
attr_value, option
|
296 |
):
|
297 |
valid_type = True
|
@@ -417,15 +393,6 @@ class Classifier:
|
|
417 |
)
|
418 |
raise
|
419 |
|
420 |
-
if (attr_to_split is not None) and (attr_to_balance is None):
|
421 |
-
logger.error(
|
422 |
-
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
423 |
-
)
|
424 |
-
raise
|
425 |
-
|
426 |
-
if not isinstance(attr_to_balance, list):
|
427 |
-
attr_to_balance = [attr_to_balance]
|
428 |
-
|
429 |
if self.classifier == "cell":
|
430 |
# remove cell states representing < rare_threshold of cells
|
431 |
data = cu.remove_rare(
|
@@ -467,8 +434,8 @@ class Classifier:
|
|
467 |
test_data_output_path = (
|
468 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
469 |
).with_suffix(".dataset")
|
470 |
-
data_dict["train"].save_to_disk(
|
471 |
-
data_dict["test"].save_to_disk(
|
472 |
elif (test_size is not None) and (self.classifier == "cell"):
|
473 |
if 1 > test_size > 0:
|
474 |
if attr_to_split is None:
|
@@ -483,8 +450,8 @@ class Classifier:
|
|
483 |
test_data_output_path = (
|
484 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
485 |
).with_suffix(".dataset")
|
486 |
-
data_dict["train"].save_to_disk(
|
487 |
-
data_dict["test"].save_to_disk(
|
488 |
else:
|
489 |
data_dict, balance_df = cu.balance_attr_splits(
|
490 |
data,
|
@@ -505,19 +472,19 @@ class Classifier:
|
|
505 |
test_data_output_path = (
|
506 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
507 |
).with_suffix(".dataset")
|
508 |
-
data_dict["train"].save_to_disk(
|
509 |
-
data_dict["test"].save_to_disk(
|
510 |
else:
|
511 |
data_output_path = (
|
512 |
Path(output_directory) / f"{output_prefix}_labeled"
|
513 |
).with_suffix(".dataset")
|
514 |
-
data.save_to_disk(
|
515 |
print(data_output_path)
|
516 |
else:
|
517 |
data_output_path = (
|
518 |
Path(output_directory) / f"{output_prefix}_labeled"
|
519 |
).with_suffix(".dataset")
|
520 |
-
data.save_to_disk(
|
521 |
|
522 |
def train_all_data(
|
523 |
self,
|
@@ -527,7 +494,6 @@ class Classifier:
|
|
527 |
output_directory,
|
528 |
output_prefix,
|
529 |
save_eval_output=True,
|
530 |
-
gene_balance=False,
|
531 |
):
|
532 |
"""
|
533 |
Train cell state or gene classifier using all data.
|
@@ -548,9 +514,6 @@ class Classifier:
|
|
548 |
save_eval_output : bool
|
549 |
| Whether to save cross-fold eval output
|
550 |
| Saves as pickle file of dictionary of eval metrics
|
551 |
-
gene_balance : None, bool
|
552 |
-
| Whether to automatically balance genes in training set.
|
553 |
-
| Only available for binary gene classifications.
|
554 |
|
555 |
**Output**
|
556 |
|
@@ -558,12 +521,6 @@ class Classifier:
|
|
558 |
|
559 |
"""
|
560 |
|
561 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
562 |
-
logger.error(
|
563 |
-
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
564 |
-
)
|
565 |
-
raise
|
566 |
-
|
567 |
##### Load data and prepare output directory #####
|
568 |
# load numerical id to class dictionary (id:class)
|
569 |
with open(id_class_dict_file, "rb") as f:
|
@@ -595,7 +552,7 @@ class Classifier:
|
|
595 |
)
|
596 |
assert len(targets) == len(labels)
|
597 |
data = cu.prep_gene_classifier_all_data(
|
598 |
-
data, targets, labels, self.max_ncells, self.nproc
|
599 |
)
|
600 |
|
601 |
trainer = self.train_classifier(
|
@@ -614,15 +571,12 @@ class Classifier:
|
|
614 |
split_id_dict=None,
|
615 |
attr_to_split=None,
|
616 |
attr_to_balance=None,
|
617 |
-
gene_balance=False,
|
618 |
max_trials=100,
|
619 |
pval_threshold=0.1,
|
620 |
save_eval_output=True,
|
621 |
predict_eval=True,
|
622 |
predict_trainer=False,
|
623 |
n_hyperopt_trials=0,
|
624 |
-
save_gene_split_datasets=True,
|
625 |
-
debug_gene_split_datasets=False,
|
626 |
):
|
627 |
"""
|
628 |
(Cross-)validate cell state or gene classifier.
|
@@ -657,9 +611,6 @@ class Classifier:
|
|
657 |
attr_to_balance : None, list
|
658 |
| List of attribute keys on which to balance data while splitting on attr_to_split
|
659 |
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
660 |
-
gene_balance : None, bool
|
661 |
-
| Whether to automatically balance genes in training set.
|
662 |
-
| Only available for binary gene classifications.
|
663 |
max_trials : None, int
|
664 |
| Maximum number of trials of random splitting to try to achieve balanced other attribute
|
665 |
| If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
|
@@ -678,19 +629,12 @@ class Classifier:
|
|
678 |
n_hyperopt_trials : int
|
679 |
| Number of trials to run for hyperparameter optimization
|
680 |
| If 0, will not optimize hyperparameters
|
681 |
-
save_gene_split_datasets : bool
|
682 |
-
| Whether or not to save train, valid, and test gene-labeled datasets
|
683 |
"""
|
|
|
684 |
if self.num_crossval_splits == 0:
|
685 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
686 |
raise
|
687 |
|
688 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
689 |
-
logger.error(
|
690 |
-
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
691 |
-
)
|
692 |
-
raise
|
693 |
-
|
694 |
# ensure number of genes in each class is > 5 if validating model
|
695 |
if self.classifier == "gene":
|
696 |
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
@@ -771,7 +715,7 @@ class Classifier:
|
|
771 |
else:
|
772 |
# 5-fold cross-validate
|
773 |
num_cells = len(data)
|
774 |
-
fifth_cells =
|
775 |
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
776 |
start = i * fifth_cells
|
777 |
end = start + num_eval
|
@@ -828,20 +772,17 @@ class Classifier:
|
|
828 |
]
|
829 |
)
|
830 |
assert len(targets) == len(labels)
|
831 |
-
n_splits = int(1 /
|
832 |
-
skf =
|
833 |
# (Cross-)validate
|
834 |
-
|
835 |
-
for train_index, eval_index, test_index in tqdm(
|
836 |
-
skf.split(targets, labels, test_ratio)
|
837 |
-
):
|
838 |
print(
|
839 |
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
840 |
)
|
841 |
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
842 |
# filter data for examples containing classes for this split
|
843 |
# subsample to max_ncells and relabel data in column "labels"
|
844 |
-
train_data, eval_data = cu.
|
845 |
data,
|
846 |
targets,
|
847 |
labels,
|
@@ -850,42 +791,8 @@ class Classifier:
|
|
850 |
self.max_ncells,
|
851 |
iteration_num,
|
852 |
self.nproc,
|
853 |
-
gene_balance,
|
854 |
)
|
855 |
|
856 |
-
if save_gene_split_datasets is True:
|
857 |
-
for split_name in ["train", "valid"]:
|
858 |
-
labeled_dataset_output_path = (
|
859 |
-
Path(output_dir)
|
860 |
-
/ f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
|
861 |
-
).with_suffix(".dataset")
|
862 |
-
if split_name == "train":
|
863 |
-
train_data.save_to_disk(str(labeled_dataset_output_path))
|
864 |
-
elif split_name == "valid":
|
865 |
-
eval_data.save_to_disk(str(labeled_dataset_output_path))
|
866 |
-
|
867 |
-
if self.oos_test_size > 0:
|
868 |
-
test_data = cu.prep_gene_classifier_split(
|
869 |
-
data,
|
870 |
-
targets,
|
871 |
-
labels,
|
872 |
-
test_index,
|
873 |
-
"test",
|
874 |
-
self.max_ncells,
|
875 |
-
iteration_num,
|
876 |
-
self.nproc,
|
877 |
-
)
|
878 |
-
if save_gene_split_datasets is True:
|
879 |
-
test_labeled_dataset_output_path = (
|
880 |
-
Path(output_dir)
|
881 |
-
/ f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
|
882 |
-
).with_suffix(".dataset")
|
883 |
-
test_data.save_to_disk(str(test_labeled_dataset_output_path))
|
884 |
-
if debug_gene_split_datasets is True:
|
885 |
-
logger.error(
|
886 |
-
"Exiting after saving gene split datasets given debug_gene_split_datasets = True."
|
887 |
-
)
|
888 |
-
raise
|
889 |
if n_hyperopt_trials == 0:
|
890 |
trainer = self.train_classifier(
|
891 |
model_directory,
|
@@ -895,15 +802,6 @@ class Classifier:
|
|
895 |
ksplit_output_dir,
|
896 |
predict_trainer,
|
897 |
)
|
898 |
-
result = self.evaluate_model(
|
899 |
-
trainer.model,
|
900 |
-
num_classes,
|
901 |
-
id_class_dict,
|
902 |
-
eval_data,
|
903 |
-
predict_eval,
|
904 |
-
ksplit_output_dir,
|
905 |
-
output_prefix,
|
906 |
-
)
|
907 |
else:
|
908 |
trainer = self.hyperopt_classifier(
|
909 |
model_directory,
|
@@ -913,27 +811,20 @@ class Classifier:
|
|
913 |
ksplit_output_dir,
|
914 |
n_trials=n_hyperopt_trials,
|
915 |
)
|
916 |
-
|
917 |
-
|
918 |
-
ksplit_output_dir, self.model_type, num_classes
|
919 |
-
)
|
920 |
-
|
921 |
-
if self.oos_test_size > 0:
|
922 |
-
result = self.evaluate_model(
|
923 |
-
model,
|
924 |
-
num_classes,
|
925 |
-
id_class_dict,
|
926 |
-
test_data,
|
927 |
-
predict_eval,
|
928 |
-
ksplit_output_dir,
|
929 |
-
output_prefix,
|
930 |
-
)
|
931 |
else:
|
932 |
-
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
937 |
results += [result]
|
938 |
all_conf_mat = all_conf_mat + result["conf_mat"]
|
939 |
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
@@ -1034,13 +925,12 @@ class Classifier:
|
|
1034 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1035 |
|
1036 |
##### Load model and training args #####
|
1037 |
-
|
1038 |
-
|
1039 |
-
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
)
|
1044 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1045 |
model, self.classifier, train_data, output_directory
|
1046 |
)
|
@@ -1056,31 +946,18 @@ class Classifier:
|
|
1056 |
if eval_data is None:
|
1057 |
def_training_args["evaluation_strategy"] = "no"
|
1058 |
def_training_args["load_best_model_at_end"] = False
|
1059 |
-
def_training_args.update(
|
1060 |
-
{"save_strategy": "epoch", "save_total_limit": 1}
|
1061 |
-
) # only save last model for each run
|
1062 |
training_args_init = TrainingArguments(**def_training_args)
|
1063 |
|
1064 |
##### Fine-tune the model #####
|
1065 |
# define the data collator
|
1066 |
if self.classifier == "cell":
|
1067 |
-
data_collator = DataCollatorForCellClassification(
|
1068 |
-
token_dictionary=self.gene_token_dict
|
1069 |
-
)
|
1070 |
elif self.classifier == "gene":
|
1071 |
-
data_collator = DataCollatorForGeneClassification(
|
1072 |
-
token_dictionary=self.gene_token_dict
|
1073 |
-
)
|
1074 |
|
1075 |
# define function to initiate model
|
1076 |
def model_init():
|
1077 |
-
model = pu.load_model(
|
1078 |
-
self.model_type,
|
1079 |
-
num_classes,
|
1080 |
-
model_directory,
|
1081 |
-
"train",
|
1082 |
-
quantize=self.quantize,
|
1083 |
-
)
|
1084 |
|
1085 |
if self.freeze_layers is not None:
|
1086 |
def_freeze_layers = self.freeze_layers
|
@@ -1091,8 +968,7 @@ class Classifier:
|
|
1091 |
for param in module.parameters():
|
1092 |
param.requires_grad = False
|
1093 |
|
1094 |
-
|
1095 |
-
model = model.to("cuda:0")
|
1096 |
return model
|
1097 |
|
1098 |
# create the trainer
|
@@ -1142,7 +1018,6 @@ class Classifier:
|
|
1142 |
metric="eval_macro_f1",
|
1143 |
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1144 |
),
|
1145 |
-
storage_path=output_directory,
|
1146 |
)
|
1147 |
|
1148 |
return trainer
|
@@ -1205,13 +1080,11 @@ class Classifier:
|
|
1205 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1206 |
|
1207 |
##### Load model and training args #####
|
1208 |
-
|
1209 |
-
|
1210 |
-
|
1211 |
-
|
1212 |
-
|
1213 |
-
quantize=self.quantize,
|
1214 |
-
)
|
1215 |
|
1216 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1217 |
model, self.classifier, train_data, output_directory
|
@@ -1241,13 +1114,9 @@ class Classifier:
|
|
1241 |
##### Fine-tune the model #####
|
1242 |
# define the data collator
|
1243 |
if self.classifier == "cell":
|
1244 |
-
data_collator = DataCollatorForCellClassification(
|
1245 |
-
token_dictionary=self.gene_token_dict
|
1246 |
-
)
|
1247 |
elif self.classifier == "gene":
|
1248 |
-
data_collator = DataCollatorForGeneClassification(
|
1249 |
-
token_dictionary=self.gene_token_dict
|
1250 |
-
)
|
1251 |
|
1252 |
# create the trainer
|
1253 |
trainer = Trainer(
|
@@ -1369,13 +1238,11 @@ class Classifier:
|
|
1369 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1370 |
|
1371 |
# load previously fine-tuned model
|
1372 |
-
|
1373 |
-
|
1374 |
-
|
1375 |
-
|
1376 |
-
|
1377 |
-
quantize=self.quantize,
|
1378 |
-
)
|
1379 |
|
1380 |
# evaluate the model
|
1381 |
result = self.evaluate_model(
|
|
|
53 |
import numpy as np
|
54 |
import pandas as pd
|
55 |
import seaborn as sns
|
56 |
+
from sklearn.model_selection import StratifiedKFold
|
57 |
from tqdm.auto import tqdm, trange
|
58 |
from transformers import Trainer
|
59 |
from transformers.training_args import TrainingArguments
|
60 |
|
61 |
+
from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
|
|
|
|
|
|
|
|
|
62 |
from . import classifier_utils as cu
|
63 |
from . import evaluation_utils as eu
|
64 |
from . import perturber_utils as pu
|
65 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
66 |
|
67 |
sns.set()
|
68 |
|
|
|
73 |
class Classifier:
|
74 |
valid_option_dict = {
|
75 |
"classifier": {"cell", "gene"},
|
|
|
76 |
"cell_state_dict": {None, dict},
|
77 |
"gene_class_dict": {None, dict},
|
78 |
"filter_data": {None, dict},
|
|
|
86 |
"no_eval": {bool},
|
87 |
"stratify_splits_col": {None, str},
|
88 |
"forward_batch_size": {int},
|
|
|
89 |
"nproc": {int},
|
90 |
"ngpu": {int},
|
91 |
}
|
|
|
93 |
def __init__(
|
94 |
self,
|
95 |
classifier=None,
|
|
|
96 |
cell_state_dict=None,
|
97 |
gene_class_dict=None,
|
98 |
filter_data=None,
|
|
|
107 |
stratify_splits_col=None,
|
108 |
no_eval=False,
|
109 |
forward_batch_size=100,
|
|
|
110 |
nproc=4,
|
111 |
ngpu=1,
|
112 |
):
|
|
|
117 |
|
118 |
classifier : {"cell", "gene"}
|
119 |
| Whether to fine-tune a cell state or gene classifier.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
cell_state_dict : None, dict
|
121 |
| Cell states to fine-tune model to distinguish.
|
122 |
| Two-item dictionary with keys: state_key and states
|
|
|
175 |
| Otherwise, will perform eval during training.
|
176 |
forward_batch_size : int
|
177 |
| Batch size for forward pass (for evaluation, not training).
|
|
|
|
|
|
|
178 |
nproc : int
|
179 |
| Number of CPU processes to use.
|
180 |
ngpu : int
|
|
|
183 |
"""
|
184 |
|
185 |
self.classifier = classifier
|
|
|
|
|
|
|
|
|
|
|
186 |
self.cell_state_dict = cell_state_dict
|
187 |
self.gene_class_dict = gene_class_dict
|
188 |
self.filter_data = filter_data
|
|
|
201 |
self.stratify_splits_col = stratify_splits_col
|
202 |
self.no_eval = no_eval
|
203 |
self.forward_batch_size = forward_batch_size
|
|
|
204 |
self.nproc = nproc
|
205 |
self.ngpu = ngpu
|
206 |
|
|
|
222 |
] = self.cell_state_dict["states"]
|
223 |
|
224 |
# load token dictionary (Ensembl IDs:token)
|
225 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
|
|
|
|
226 |
self.gene_token_dict = pickle.load(f)
|
227 |
|
228 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
|
|
245 |
f"Genes to classify {missing_genes} are not in token dictionary."
|
246 |
)
|
247 |
self.gene_class_dict = {
|
248 |
+
k: set([self.gene_token_dict.get(gene) for gene in v])
|
249 |
for k, v in self.gene_class_dict.items()
|
250 |
}
|
251 |
empty_classes = []
|
|
|
267 |
continue
|
268 |
valid_type = False
|
269 |
for option in valid_options:
|
270 |
+
if (option in [int, float, list, dict, bool]) and isinstance(
|
271 |
attr_value, option
|
272 |
):
|
273 |
valid_type = True
|
|
|
393 |
)
|
394 |
raise
|
395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
if self.classifier == "cell":
|
397 |
# remove cell states representing < rare_threshold of cells
|
398 |
data = cu.remove_rare(
|
|
|
434 |
test_data_output_path = (
|
435 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
436 |
).with_suffix(".dataset")
|
437 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
438 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
439 |
elif (test_size is not None) and (self.classifier == "cell"):
|
440 |
if 1 > test_size > 0:
|
441 |
if attr_to_split is None:
|
|
|
450 |
test_data_output_path = (
|
451 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
452 |
).with_suffix(".dataset")
|
453 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
454 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
455 |
else:
|
456 |
data_dict, balance_df = cu.balance_attr_splits(
|
457 |
data,
|
|
|
472 |
test_data_output_path = (
|
473 |
Path(output_directory) / f"{output_prefix}_labeled_test"
|
474 |
).with_suffix(".dataset")
|
475 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
476 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
477 |
else:
|
478 |
data_output_path = (
|
479 |
Path(output_directory) / f"{output_prefix}_labeled"
|
480 |
).with_suffix(".dataset")
|
481 |
+
data.save_to_disk(data_output_path)
|
482 |
print(data_output_path)
|
483 |
else:
|
484 |
data_output_path = (
|
485 |
Path(output_directory) / f"{output_prefix}_labeled"
|
486 |
).with_suffix(".dataset")
|
487 |
+
data.save_to_disk(data_output_path)
|
488 |
|
489 |
def train_all_data(
|
490 |
self,
|
|
|
494 |
output_directory,
|
495 |
output_prefix,
|
496 |
save_eval_output=True,
|
|
|
497 |
):
|
498 |
"""
|
499 |
Train cell state or gene classifier using all data.
|
|
|
514 |
save_eval_output : bool
|
515 |
| Whether to save cross-fold eval output
|
516 |
| Saves as pickle file of dictionary of eval metrics
|
|
|
|
|
|
|
517 |
|
518 |
**Output**
|
519 |
|
|
|
521 |
|
522 |
"""
|
523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
##### Load data and prepare output directory #####
|
525 |
# load numerical id to class dictionary (id:class)
|
526 |
with open(id_class_dict_file, "rb") as f:
|
|
|
552 |
)
|
553 |
assert len(targets) == len(labels)
|
554 |
data = cu.prep_gene_classifier_all_data(
|
555 |
+
data, targets, labels, self.max_ncells, self.nproc
|
556 |
)
|
557 |
|
558 |
trainer = self.train_classifier(
|
|
|
571 |
split_id_dict=None,
|
572 |
attr_to_split=None,
|
573 |
attr_to_balance=None,
|
|
|
574 |
max_trials=100,
|
575 |
pval_threshold=0.1,
|
576 |
save_eval_output=True,
|
577 |
predict_eval=True,
|
578 |
predict_trainer=False,
|
579 |
n_hyperopt_trials=0,
|
|
|
|
|
580 |
):
|
581 |
"""
|
582 |
(Cross-)validate cell state or gene classifier.
|
|
|
611 |
attr_to_balance : None, list
|
612 |
| List of attribute keys on which to balance data while splitting on attr_to_split
|
613 |
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
|
|
|
|
|
|
614 |
max_trials : None, int
|
615 |
| Maximum number of trials of random splitting to try to achieve balanced other attribute
|
616 |
| If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
|
|
|
629 |
n_hyperopt_trials : int
|
630 |
| Number of trials to run for hyperparameter optimization
|
631 |
| If 0, will not optimize hyperparameters
|
|
|
|
|
632 |
"""
|
633 |
+
|
634 |
if self.num_crossval_splits == 0:
|
635 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
636 |
raise
|
637 |
|
|
|
|
|
|
|
|
|
|
|
|
|
638 |
# ensure number of genes in each class is > 5 if validating model
|
639 |
if self.classifier == "gene":
|
640 |
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
|
|
715 |
else:
|
716 |
# 5-fold cross-validate
|
717 |
num_cells = len(data)
|
718 |
+
fifth_cells = num_cells * 0.2
|
719 |
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
720 |
start = i * fifth_cells
|
721 |
end = start + num_eval
|
|
|
772 |
]
|
773 |
)
|
774 |
assert len(targets) == len(labels)
|
775 |
+
n_splits = int(1 / self.eval_size)
|
776 |
+
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
|
777 |
# (Cross-)validate
|
778 |
+
for train_index, eval_index in tqdm(skf.split(targets, labels)):
|
|
|
|
|
|
|
779 |
print(
|
780 |
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
781 |
)
|
782 |
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
783 |
# filter data for examples containing classes for this split
|
784 |
# subsample to max_ncells and relabel data in column "labels"
|
785 |
+
train_data, eval_data = cu.prep_gene_classifier_split(
|
786 |
data,
|
787 |
targets,
|
788 |
labels,
|
|
|
791 |
self.max_ncells,
|
792 |
iteration_num,
|
793 |
self.nproc,
|
|
|
794 |
)
|
795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
796 |
if n_hyperopt_trials == 0:
|
797 |
trainer = self.train_classifier(
|
798 |
model_directory,
|
|
|
802 |
ksplit_output_dir,
|
803 |
predict_trainer,
|
804 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
805 |
else:
|
806 |
trainer = self.hyperopt_classifier(
|
807 |
model_directory,
|
|
|
811 |
ksplit_output_dir,
|
812 |
n_trials=n_hyperopt_trials,
|
813 |
)
|
814 |
+
if iteration_num == self.num_crossval_splits:
|
815 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
816 |
else:
|
817 |
+
iteration_num = iteration_num + 1
|
818 |
+
continue
|
819 |
+
result = self.evaluate_model(
|
820 |
+
trainer.model,
|
821 |
+
num_classes,
|
822 |
+
id_class_dict,
|
823 |
+
eval_data,
|
824 |
+
predict_eval,
|
825 |
+
ksplit_output_dir,
|
826 |
+
output_prefix,
|
827 |
+
)
|
828 |
results += [result]
|
829 |
all_conf_mat = all_conf_mat + result["conf_mat"]
|
830 |
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
|
|
925 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
926 |
|
927 |
##### Load model and training args #####
|
928 |
+
if self.classifier == "cell":
|
929 |
+
model_type = "CellClassifier"
|
930 |
+
elif self.classifier == "gene":
|
931 |
+
model_type = "GeneClassifier"
|
932 |
+
|
933 |
+
model = pu.load_model(model_type, num_classes, model_directory, "train")
|
|
|
934 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
935 |
model, self.classifier, train_data, output_directory
|
936 |
)
|
|
|
946 |
if eval_data is None:
|
947 |
def_training_args["evaluation_strategy"] = "no"
|
948 |
def_training_args["load_best_model_at_end"] = False
|
|
|
|
|
|
|
949 |
training_args_init = TrainingArguments(**def_training_args)
|
950 |
|
951 |
##### Fine-tune the model #####
|
952 |
# define the data collator
|
953 |
if self.classifier == "cell":
|
954 |
+
data_collator = DataCollatorForCellClassification()
|
|
|
|
|
955 |
elif self.classifier == "gene":
|
956 |
+
data_collator = DataCollatorForGeneClassification()
|
|
|
|
|
957 |
|
958 |
# define function to initiate model
|
959 |
def model_init():
|
960 |
+
model = pu.load_model(model_type, num_classes, model_directory, "train")
|
|
|
|
|
|
|
|
|
|
|
|
|
961 |
|
962 |
if self.freeze_layers is not None:
|
963 |
def_freeze_layers = self.freeze_layers
|
|
|
968 |
for param in module.parameters():
|
969 |
param.requires_grad = False
|
970 |
|
971 |
+
model = model.to("cuda:0")
|
|
|
972 |
return model
|
973 |
|
974 |
# create the trainer
|
|
|
1018 |
metric="eval_macro_f1",
|
1019 |
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1020 |
),
|
|
|
1021 |
)
|
1022 |
|
1023 |
return trainer
|
|
|
1080 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1081 |
|
1082 |
##### Load model and training args #####
|
1083 |
+
if self.classifier == "cell":
|
1084 |
+
model_type = "CellClassifier"
|
1085 |
+
elif self.classifier == "gene":
|
1086 |
+
model_type = "GeneClassifier"
|
1087 |
+
model = pu.load_model(model_type, num_classes, model_directory, "train")
|
|
|
|
|
1088 |
|
1089 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1090 |
model, self.classifier, train_data, output_directory
|
|
|
1114 |
##### Fine-tune the model #####
|
1115 |
# define the data collator
|
1116 |
if self.classifier == "cell":
|
1117 |
+
data_collator = DataCollatorForCellClassification()
|
|
|
|
|
1118 |
elif self.classifier == "gene":
|
1119 |
+
data_collator = DataCollatorForGeneClassification()
|
|
|
|
|
1120 |
|
1121 |
# create the trainer
|
1122 |
trainer = Trainer(
|
|
|
1238 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1239 |
|
1240 |
# load previously fine-tuned model
|
1241 |
+
if self.classifier == "cell":
|
1242 |
+
model_type = "CellClassifier"
|
1243 |
+
elif self.classifier == "gene":
|
1244 |
+
model_type = "GeneClassifier"
|
1245 |
+
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
|
|
|
|
1246 |
|
1247 |
# evaluate the model
|
1248 |
result = self.evaluate_model(
|
geneformer/classifier_utils.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1 |
-
import json
|
2 |
import logging
|
3 |
-
import os
|
4 |
import random
|
5 |
from collections import Counter, defaultdict
|
6 |
|
@@ -8,7 +6,6 @@ import numpy as np
|
|
8 |
import pandas as pd
|
9 |
from scipy.stats import chisquare, ranksums
|
10 |
from sklearn.metrics import accuracy_score, f1_score
|
11 |
-
from sklearn.model_selection import StratifiedKFold, train_test_split
|
12 |
|
13 |
from . import perturber_utils as pu
|
14 |
|
@@ -136,98 +133,64 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
|
|
136 |
]
|
137 |
|
138 |
|
139 |
-
def prep_gene_classifier_train_eval_split(
|
140 |
-
data,
|
141 |
-
targets,
|
142 |
-
labels,
|
143 |
-
train_index,
|
144 |
-
eval_index,
|
145 |
-
max_ncells,
|
146 |
-
iteration_num,
|
147 |
-
num_proc,
|
148 |
-
balance=False,
|
149 |
-
):
|
150 |
-
# generate cross-validation splits
|
151 |
-
train_data = prep_gene_classifier_split(
|
152 |
-
data,
|
153 |
-
targets,
|
154 |
-
labels,
|
155 |
-
train_index,
|
156 |
-
"train",
|
157 |
-
max_ncells,
|
158 |
-
iteration_num,
|
159 |
-
num_proc,
|
160 |
-
balance,
|
161 |
-
)
|
162 |
-
eval_data = prep_gene_classifier_split(
|
163 |
-
data,
|
164 |
-
targets,
|
165 |
-
labels,
|
166 |
-
eval_index,
|
167 |
-
"eval",
|
168 |
-
max_ncells,
|
169 |
-
iteration_num,
|
170 |
-
num_proc,
|
171 |
-
balance,
|
172 |
-
)
|
173 |
-
return train_data, eval_data
|
174 |
-
|
175 |
-
|
176 |
def prep_gene_classifier_split(
|
177 |
-
data,
|
178 |
-
targets,
|
179 |
-
labels,
|
180 |
-
index,
|
181 |
-
subset_name,
|
182 |
-
max_ncells,
|
183 |
-
iteration_num,
|
184 |
-
num_proc,
|
185 |
-
balance=False,
|
186 |
):
|
187 |
# generate cross-validation splits
|
188 |
targets = np.array(targets)
|
189 |
labels = np.array(labels)
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
193 |
|
194 |
# function to filter by whether contains train or eval labels
|
195 |
-
def
|
196 |
-
a =
|
|
|
|
|
|
|
|
|
|
|
197 |
b = example["input_ids"]
|
198 |
return not set(a).isdisjoint(b)
|
199 |
|
200 |
# filter dataset for examples containing classes for this split
|
201 |
-
logger.info(f"Filtering data for
|
202 |
-
|
203 |
logger.info(
|
204 |
-
f"Filtered {round((1-len(
|
|
|
|
|
|
|
|
|
|
|
205 |
)
|
206 |
-
|
207 |
-
# balance gene subsets if train
|
208 |
-
if (subset_name == "train") and (balance is True):
|
209 |
-
subset_data, label_dict_subset = balance_gene_split(
|
210 |
-
subset_data, label_dict_subset, num_proc
|
211 |
-
)
|
212 |
|
213 |
# subsample to max_ncells
|
214 |
-
|
|
|
215 |
|
216 |
# relabel genes for this split
|
217 |
-
def
|
218 |
example["labels"] = [
|
219 |
-
|
220 |
]
|
221 |
return example
|
222 |
|
223 |
-
|
|
|
|
|
|
|
|
|
224 |
|
225 |
-
|
|
|
226 |
|
|
|
227 |
|
228 |
-
|
229 |
-
|
230 |
-
):
|
231 |
targets = np.array(targets)
|
232 |
labels = np.array(labels)
|
233 |
label_dict_train = dict(zip(targets, labels))
|
@@ -245,11 +208,6 @@ def prep_gene_classifier_all_data(
|
|
245 |
f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
|
246 |
)
|
247 |
|
248 |
-
if balance is True:
|
249 |
-
train_data, label_dict_train = balance_gene_split(
|
250 |
-
train_data, label_dict_train, num_proc
|
251 |
-
)
|
252 |
-
|
253 |
# subsample to max_ncells
|
254 |
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
255 |
|
@@ -265,145 +223,6 @@ def prep_gene_classifier_all_data(
|
|
265 |
return train_data
|
266 |
|
267 |
|
268 |
-
def balance_gene_split(subset_data, label_dict_subset, num_proc):
|
269 |
-
# count occurrence of genes in each label category
|
270 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
271 |
-
subset_data, label_dict_subset, num_proc
|
272 |
-
)
|
273 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
274 |
-
|
275 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
276 |
-
# gene sets already balanced
|
277 |
-
logger.info(
|
278 |
-
"Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
|
279 |
-
)
|
280 |
-
return subset_data, label_dict_subset
|
281 |
-
else:
|
282 |
-
label_ratio_0to1_orig = label_ratio_0to1 + 0
|
283 |
-
label_dict_subset_orig = label_dict_subset.copy()
|
284 |
-
# balance gene sets
|
285 |
-
max_ntrials = 25
|
286 |
-
boost = 1
|
287 |
-
if label_ratio_0to1 > 10 / 8:
|
288 |
-
# downsample label 0
|
289 |
-
for i in range(max_ntrials):
|
290 |
-
label0 = 0
|
291 |
-
label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
|
292 |
-
label0_ngenes = len(label0_genes)
|
293 |
-
label0_nremove = max(
|
294 |
-
1,
|
295 |
-
int(
|
296 |
-
np.floor(
|
297 |
-
label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
|
298 |
-
)
|
299 |
-
),
|
300 |
-
)
|
301 |
-
random.seed(i)
|
302 |
-
label0_remove_genes = random.sample(label0_genes, label0_nremove)
|
303 |
-
label_dict_subset_new = {
|
304 |
-
k: v
|
305 |
-
for k, v in label_dict_subset.items()
|
306 |
-
if k not in label0_remove_genes
|
307 |
-
}
|
308 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
309 |
-
subset_data, label_dict_subset_new, num_proc
|
310 |
-
)
|
311 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
312 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
313 |
-
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
314 |
-
return filter_data_balanced_genes(
|
315 |
-
subset_data, label_dict_subset_new, num_proc
|
316 |
-
)
|
317 |
-
elif label_ratio_0to1 > 10 / 8:
|
318 |
-
boost = boost * 1.1
|
319 |
-
elif label_ratio_0to1 < 8 / 10:
|
320 |
-
boost = boost * 0.9
|
321 |
-
else:
|
322 |
-
# downsample label 1
|
323 |
-
for i in range(max_ntrials):
|
324 |
-
label1 = 1
|
325 |
-
label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
|
326 |
-
label1_ngenes = len(label1_genes)
|
327 |
-
label1_nremove = max(
|
328 |
-
1,
|
329 |
-
int(
|
330 |
-
np.floor(
|
331 |
-
label1_ngenes
|
332 |
-
- label1_ngenes / ((1 / label_ratio_0to1) * boost)
|
333 |
-
)
|
334 |
-
),
|
335 |
-
)
|
336 |
-
random.seed(i)
|
337 |
-
label1_remove_genes = random.sample(label1_genes, label1_nremove)
|
338 |
-
label_dict_subset_new = {
|
339 |
-
k: v
|
340 |
-
for k, v in label_dict_subset.items()
|
341 |
-
if k not in label1_remove_genes
|
342 |
-
}
|
343 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
344 |
-
subset_data, label_dict_subset_new, num_proc
|
345 |
-
)
|
346 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
347 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
348 |
-
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
349 |
-
return filter_data_balanced_genes(
|
350 |
-
subset_data, label_dict_subset_new, num_proc
|
351 |
-
)
|
352 |
-
elif label_ratio_0to1 < 8 / 10:
|
353 |
-
boost = boost * 1.1
|
354 |
-
elif label_ratio_0to1 > 10 / 8:
|
355 |
-
boost = boost * 0.9
|
356 |
-
|
357 |
-
assert i + 1 == max_ntrials
|
358 |
-
if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
|
359 |
-
10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
|
360 |
-
):
|
361 |
-
label_ratio_0to1 = label_ratio_0to1_orig
|
362 |
-
label_dict_subset_new = label_dict_subset_orig
|
363 |
-
logger.warning(
|
364 |
-
f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
|
365 |
-
)
|
366 |
-
return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
|
367 |
-
|
368 |
-
|
369 |
-
def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
|
370 |
-
def count_targets(example):
|
371 |
-
labels = [
|
372 |
-
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
373 |
-
]
|
374 |
-
counter_labels = Counter(labels)
|
375 |
-
# get count of labels 0 or 1, or if absent, return 0
|
376 |
-
example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
|
377 |
-
return example
|
378 |
-
|
379 |
-
subset_data = subset_data.map(count_targets, num_proc=num_proc)
|
380 |
-
|
381 |
-
label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
|
382 |
-
label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
|
383 |
-
|
384 |
-
subset_data = subset_data.remove_columns("labels_counts")
|
385 |
-
|
386 |
-
return label0_counts, label1_counts
|
387 |
-
|
388 |
-
|
389 |
-
def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
|
390 |
-
# function to filter by whether contains labels
|
391 |
-
def if_contains_subset_label(example):
|
392 |
-
a = list(label_dict_subset.keys())
|
393 |
-
b = example["input_ids"]
|
394 |
-
return not set(a).isdisjoint(b)
|
395 |
-
|
396 |
-
# filter dataset for examples containing classes for this split
|
397 |
-
logger.info("Filtering data for balanced genes")
|
398 |
-
subset_data_len_orig = len(subset_data)
|
399 |
-
subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
|
400 |
-
logger.info(
|
401 |
-
f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
|
402 |
-
)
|
403 |
-
|
404 |
-
return subset_data, label_dict_subset
|
405 |
-
|
406 |
-
|
407 |
def balance_attr_splits(
|
408 |
data,
|
409 |
attr_to_split,
|
@@ -490,7 +309,7 @@ def balance_attr_splits(
|
|
490 |
exp_counts[cat] * sum(obs) / sum(exp_counts.values())
|
491 |
for cat in all_categ
|
492 |
]
|
493 |
-
|
494 |
train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
|
495 |
eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
|
496 |
df_vals += [train_attr_counts, eval_attr_counts, pval]
|
@@ -604,45 +423,3 @@ def get_default_train_args(model, classifier, data, output_dir):
|
|
604 |
training_args.update(default_training_args)
|
605 |
|
606 |
return training_args, freeze_layers
|
607 |
-
|
608 |
-
|
609 |
-
def load_best_model(directory, model_type, num_classes, mode="eval"):
|
610 |
-
file_dict = dict()
|
611 |
-
for subdir, dirs, files in os.walk(directory):
|
612 |
-
for file in files:
|
613 |
-
if file.endswith("result.json"):
|
614 |
-
with open(f"{subdir}/{file}", "rb") as fp:
|
615 |
-
result_json = json.load(fp)
|
616 |
-
file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
|
617 |
-
file_df = pd.DataFrame(
|
618 |
-
{"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
|
619 |
-
)
|
620 |
-
model_superdir = (
|
621 |
-
"run-"
|
622 |
-
+ file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
|
623 |
-
.split("_objective_")[2]
|
624 |
-
.split("_")[0]
|
625 |
-
)
|
626 |
-
|
627 |
-
for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
|
628 |
-
for file in files:
|
629 |
-
if file.endswith("model.safetensors"):
|
630 |
-
model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
|
631 |
-
return model
|
632 |
-
|
633 |
-
|
634 |
-
class StratifiedKFold3(StratifiedKFold):
|
635 |
-
def split(self, targets, labels, test_ratio=0.5, groups=None):
|
636 |
-
s = super().split(targets, labels, groups)
|
637 |
-
for train_indxs, test_indxs in s:
|
638 |
-
if test_ratio == 0:
|
639 |
-
yield train_indxs, test_indxs, None
|
640 |
-
else:
|
641 |
-
labels_test = np.array(labels)[test_indxs]
|
642 |
-
valid_indxs, test_indxs = train_test_split(
|
643 |
-
test_indxs,
|
644 |
-
stratify=labels_test,
|
645 |
-
test_size=test_ratio,
|
646 |
-
random_state=0,
|
647 |
-
)
|
648 |
-
yield train_indxs, valid_indxs, test_indxs
|
|
|
|
|
1 |
import logging
|
|
|
2 |
import random
|
3 |
from collections import Counter, defaultdict
|
4 |
|
|
|
6 |
import pandas as pd
|
7 |
from scipy.stats import chisquare, ranksums
|
8 |
from sklearn.metrics import accuracy_score, f1_score
|
|
|
9 |
|
10 |
from . import perturber_utils as pu
|
11 |
|
|
|
133 |
]
|
134 |
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
def prep_gene_classifier_split(
|
137 |
+
data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
):
|
139 |
# generate cross-validation splits
|
140 |
targets = np.array(targets)
|
141 |
labels = np.array(labels)
|
142 |
+
targets_train, targets_eval = targets[train_index], targets[eval_index]
|
143 |
+
labels_train, labels_eval = labels[train_index], labels[eval_index]
|
144 |
+
label_dict_train = dict(zip(targets_train, labels_train))
|
145 |
+
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
146 |
|
147 |
# function to filter by whether contains train or eval labels
|
148 |
+
def if_contains_train_label(example):
|
149 |
+
a = targets_train
|
150 |
+
b = example["input_ids"]
|
151 |
+
return not set(a).isdisjoint(b)
|
152 |
+
|
153 |
+
def if_contains_eval_label(example):
|
154 |
+
a = targets_eval
|
155 |
b = example["input_ids"]
|
156 |
return not set(a).isdisjoint(b)
|
157 |
|
158 |
# filter dataset for examples containing classes for this split
|
159 |
+
logger.info(f"Filtering training data for genes in split {iteration_num}")
|
160 |
+
train_data = data.filter(if_contains_train_label, num_proc=num_proc)
|
161 |
logger.info(
|
162 |
+
f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
|
163 |
+
)
|
164 |
+
logger.info(f"Filtering evalation data for genes in split {iteration_num}")
|
165 |
+
eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
|
166 |
+
logger.info(
|
167 |
+
f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
|
168 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
# subsample to max_ncells
|
171 |
+
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
172 |
+
eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
|
173 |
|
174 |
# relabel genes for this split
|
175 |
+
def train_classes_to_ids(example):
|
176 |
example["labels"] = [
|
177 |
+
label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
|
178 |
]
|
179 |
return example
|
180 |
|
181 |
+
def eval_classes_to_ids(example):
|
182 |
+
example["labels"] = [
|
183 |
+
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
|
184 |
+
]
|
185 |
+
return example
|
186 |
|
187 |
+
train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
|
188 |
+
eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
|
189 |
|
190 |
+
return train_data, eval_data
|
191 |
|
192 |
+
|
193 |
+
def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
|
|
|
194 |
targets = np.array(targets)
|
195 |
labels = np.array(labels)
|
196 |
label_dict_train = dict(zip(targets, labels))
|
|
|
208 |
f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
|
209 |
)
|
210 |
|
|
|
|
|
|
|
|
|
|
|
211 |
# subsample to max_ncells
|
212 |
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
213 |
|
|
|
223 |
return train_data
|
224 |
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
def balance_attr_splits(
|
227 |
data,
|
228 |
attr_to_split,
|
|
|
309 |
exp_counts[cat] * sum(obs) / sum(exp_counts.values())
|
310 |
for cat in all_categ
|
311 |
]
|
312 |
+
chisquare(f_obs=obs, f_exp=exp).pvalue
|
313 |
train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
|
314 |
eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
|
315 |
df_vals += [train_attr_counts, eval_attr_counts, pval]
|
|
|
423 |
training_args.update(default_training_args)
|
424 |
|
425 |
return training_args, freeze_layers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/collator_for_classification.py
CHANGED
@@ -1,22 +1,24 @@
|
|
1 |
"""
|
2 |
Geneformer collator for gene and cell classification.
|
|
|
3 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
4 |
"""
|
5 |
-
|
|
|
6 |
import warnings
|
7 |
from enum import Enum
|
8 |
from typing import Dict, List, Optional, Union
|
9 |
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
from transformers import (
|
13 |
-
BatchEncoding,
|
14 |
DataCollatorForTokenClassification,
|
15 |
SpecialTokensMixin,
|
|
|
16 |
)
|
17 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
|
|
|
|
|
20 |
EncodedInput = List[int]
|
21 |
logger = logging.get_logger(__name__)
|
22 |
VERY_LARGE_INTEGER = int(
|
@@ -28,7 +30,6 @@ LARGE_INTEGER = int(
|
|
28 |
|
29 |
# precollator functions
|
30 |
|
31 |
-
|
32 |
class ExplicitEnum(Enum):
|
33 |
"""
|
34 |
Enum with more explicit error message for missing values.
|
@@ -41,7 +42,6 @@ class ExplicitEnum(Enum):
|
|
41 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
42 |
)
|
43 |
|
44 |
-
|
45 |
class TruncationStrategy(ExplicitEnum):
|
46 |
"""
|
47 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -54,6 +54,7 @@ class TruncationStrategy(ExplicitEnum):
|
|
54 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
|
56 |
|
|
|
57 |
class PaddingStrategy(ExplicitEnum):
|
58 |
"""
|
59 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
@@ -65,6 +66,7 @@ class PaddingStrategy(ExplicitEnum):
|
|
65 |
DO_NOT_PAD = "do_not_pad"
|
66 |
|
67 |
|
|
|
68 |
class TensorType(ExplicitEnum):
|
69 |
"""
|
70 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -76,41 +78,21 @@ class TensorType(ExplicitEnum):
|
|
76 |
NUMPY = "np"
|
77 |
JAX = "jax"
|
78 |
|
79 |
-
|
80 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
self.token_dictionary.get("<pad>"),
|
92 |
-
]
|
93 |
-
|
94 |
-
@property
|
95 |
-
def all_special_ids(self):
|
96 |
-
return self._all_special_ids
|
97 |
-
|
98 |
-
@property
|
99 |
-
def mask_token_id(self):
|
100 |
-
return self._mask_token_id
|
101 |
-
|
102 |
-
@property
|
103 |
-
def pad_token_id(self):
|
104 |
-
return self._pad_token_id
|
105 |
|
106 |
def _get_padding_truncation_strategies(
|
107 |
-
self,
|
108 |
-
padding=True,
|
109 |
-
truncation=False,
|
110 |
-
max_length=None,
|
111 |
-
pad_to_multiple_of=None,
|
112 |
-
verbose=True,
|
113 |
-
**kwargs,
|
114 |
):
|
115 |
"""
|
116 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
@@ -123,9 +105,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
123 |
# If you only set max_length, it activates truncation for max_length
|
124 |
if max_length is not None and padding is False and truncation is False:
|
125 |
if verbose:
|
126 |
-
if not self.deprecation_warnings.get(
|
127 |
-
"Truncation-not-explicitly-activated", False
|
128 |
-
):
|
129 |
logger.warning(
|
130 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
131 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
@@ -153,9 +133,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
153 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
154 |
elif padding is not False:
|
155 |
if padding is True:
|
156 |
-
padding_strategy =
|
157 |
-
PaddingStrategy.LONGEST
|
158 |
-
) # Default to pad to the longest sequence in the batch
|
159 |
elif not isinstance(padding, PaddingStrategy):
|
160 |
padding_strategy = PaddingStrategy(padding)
|
161 |
elif isinstance(padding, PaddingStrategy):
|
@@ -195,9 +173,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
195 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
196 |
if self.model_max_length > LARGE_INTEGER:
|
197 |
if verbose:
|
198 |
-
if not self.deprecation_warnings.get(
|
199 |
-
"Asking-to-pad-to-max_length", False
|
200 |
-
):
|
201 |
logger.warning(
|
202 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
203 |
"Default to no padding."
|
@@ -210,24 +186,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
210 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
211 |
if self.model_max_length > LARGE_INTEGER:
|
212 |
if verbose:
|
213 |
-
if not self.deprecation_warnings.get(
|
214 |
-
"Asking-to-truncate-to-max_length", False
|
215 |
-
):
|
216 |
logger.warning(
|
217 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
218 |
"Default to no truncation."
|
219 |
)
|
220 |
-
self.deprecation_warnings[
|
221 |
-
"Asking-to-truncate-to-max_length"
|
222 |
-
] = True
|
223 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
224 |
else:
|
225 |
max_length = self.model_max_length
|
226 |
|
227 |
# Test if we have a padding token
|
228 |
-
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
|
229 |
-
not self.pad_token or self.pad_token_id < 0
|
230 |
-
):
|
231 |
raise ValueError(
|
232 |
"Asking to pad but the tokenizer does not have a padding token. "
|
233 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
@@ -258,7 +228,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
258 |
Dict[str, List[EncodedInput]],
|
259 |
List[Dict[str, EncodedInput]],
|
260 |
],
|
261 |
-
class_type,
|
262 |
padding: Union[bool, str, PaddingStrategy] = True,
|
263 |
max_length: Optional[int] = None,
|
264 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -269,23 +239,29 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
269 |
"""
|
270 |
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
271 |
in the batch.
|
|
|
272 |
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
273 |
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
|
|
274 |
.. note::
|
|
|
275 |
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
276 |
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
277 |
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
|
|
278 |
Args:
|
279 |
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
280 |
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
281 |
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
282 |
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
283 |
well as in a PyTorch Dataloader collate function.
|
|
|
284 |
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
285 |
see the note above for the return type.
|
286 |
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
287 |
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
288 |
index) among:
|
|
|
289 |
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
290 |
single sequence if provided).
|
291 |
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
@@ -296,14 +272,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
296 |
Maximum length of the returned list and optionally padding length (see above).
|
297 |
pad_to_multiple_of (:obj:`int`, `optional`):
|
298 |
If set will pad the sequence to a multiple of the provided value.
|
|
|
299 |
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
300 |
>= 7.5 (Volta).
|
301 |
return_attention_mask (:obj:`bool`, `optional`):
|
302 |
Whether to return the attention mask. If left to the default, will return the attention mask according
|
303 |
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
|
|
304 |
`What are attention masks? <../glossary.html#attention-mask>`__
|
305 |
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
306 |
If set, will return tensors instead of list of python integers. Acceptable values are:
|
|
|
307 |
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
308 |
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
309 |
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
@@ -312,13 +291,8 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
312 |
"""
|
313 |
# If we have a list of dicts, let's convert it in a dict of lists
|
314 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
315 |
-
if isinstance(encoded_inputs, (list, tuple)) and isinstance(
|
316 |
-
encoded_inputs[0]
|
317 |
-
):
|
318 |
-
encoded_inputs = {
|
319 |
-
key: [example[key] for example in encoded_inputs]
|
320 |
-
for key in encoded_inputs[0].keys()
|
321 |
-
}
|
322 |
|
323 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
324 |
if self.model_input_names[0] not in encoded_inputs:
|
@@ -412,7 +386,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
412 |
def _pad(
|
413 |
self,
|
414 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
415 |
-
class_type,
|
416 |
max_length: Optional[int] = None,
|
417 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
418 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -420,15 +394,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
420 |
) -> dict:
|
421 |
"""
|
422 |
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
|
|
423 |
Args:
|
424 |
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
425 |
max_length: maximum length of the returned list and optionally padding length (see below).
|
426 |
Will truncate by taking into account the special tokens.
|
427 |
padding_strategy: PaddingStrategy to use for padding.
|
|
|
428 |
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
429 |
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
430 |
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
431 |
The tokenizer padding sides are defined in self.padding_side:
|
|
|
432 |
- 'left': pads on the left of the sequences
|
433 |
- 'right': pads on the right of the sequences
|
434 |
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
@@ -445,73 +422,46 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
445 |
if padding_strategy == PaddingStrategy.LONGEST:
|
446 |
max_length = len(required_input)
|
447 |
|
448 |
-
if (
|
449 |
-
max_length is not None
|
450 |
-
and pad_to_multiple_of is not None
|
451 |
-
and (max_length % pad_to_multiple_of != 0)
|
452 |
-
):
|
453 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
454 |
|
455 |
-
needs_to_be_padded = (
|
456 |
-
padding_strategy != PaddingStrategy.DO_NOT_PAD
|
457 |
-
and len(required_input) != max_length
|
458 |
-
)
|
459 |
|
460 |
if needs_to_be_padded:
|
461 |
difference = max_length - len(required_input)
|
462 |
if self.padding_side == "right":
|
463 |
if return_attention_mask:
|
464 |
-
encoded_inputs["attention_mask"] = [1] * len(required_input) + [
|
465 |
-
0
|
466 |
-
] * difference
|
467 |
if "token_type_ids" in encoded_inputs:
|
468 |
encoded_inputs["token_type_ids"] = (
|
469 |
-
encoded_inputs["token_type_ids"]
|
470 |
-
+ [self.pad_token_type_id] * difference
|
471 |
)
|
472 |
if "special_tokens_mask" in encoded_inputs:
|
473 |
-
encoded_inputs["special_tokens_mask"] =
|
474 |
-
|
475 |
-
)
|
476 |
-
encoded_inputs[self.model_input_names[0]] = (
|
477 |
-
required_input + [self.pad_token_id] * difference
|
478 |
-
)
|
479 |
if class_type == "gene":
|
480 |
-
encoded_inputs["labels"] =
|
481 |
-
encoded_inputs["labels"] + [-100] * difference
|
482 |
-
)
|
483 |
elif self.padding_side == "left":
|
484 |
if return_attention_mask:
|
485 |
-
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
|
486 |
-
required_input
|
487 |
-
)
|
488 |
if "token_type_ids" in encoded_inputs:
|
489 |
-
encoded_inputs["token_type_ids"] = [
|
490 |
-
|
491 |
-
]
|
492 |
if "special_tokens_mask" in encoded_inputs:
|
493 |
-
encoded_inputs["special_tokens_mask"] = [
|
494 |
-
|
495 |
-
] * difference + encoded_inputs["special_tokens_mask"]
|
496 |
-
encoded_inputs[self.model_input_names[0]] = [
|
497 |
-
self.pad_token_id
|
498 |
-
] * difference + required_input
|
499 |
if class_type == "gene":
|
500 |
-
encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
|
501 |
-
"labels"
|
502 |
-
]
|
503 |
else:
|
504 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
505 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
506 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
507 |
-
|
508 |
return encoded_inputs
|
509 |
|
510 |
def get_special_tokens_mask(
|
511 |
-
self,
|
512 |
-
token_ids_0: List[int],
|
513 |
-
token_ids_1: Optional[List[int]] = None,
|
514 |
-
already_has_special_tokens: bool = False,
|
515 |
) -> List[int]:
|
516 |
"""
|
517 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
@@ -535,15 +485,11 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
535 |
|
536 |
all_special_ids = self.all_special_ids # cache the property
|
537 |
|
538 |
-
special_tokens_mask = [
|
539 |
-
1 if token in all_special_ids else 0 for token in token_ids_0
|
540 |
-
]
|
541 |
|
542 |
return special_tokens_mask
|
543 |
|
544 |
-
def convert_tokens_to_ids(
|
545 |
-
self, tokens: Union[str, List[str]]
|
546 |
-
) -> Union[int, List[int]]:
|
547 |
"""
|
548 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
549 |
vocabulary.
|
@@ -567,15 +513,14 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
567 |
if token is None:
|
568 |
return None
|
569 |
|
570 |
-
return
|
571 |
|
572 |
def __len__(self):
|
573 |
-
return len(
|
574 |
|
575 |
|
576 |
# collator functions
|
577 |
|
578 |
-
|
579 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
580 |
"""
|
581 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
@@ -601,33 +546,25 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
601 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
602 |
"""
|
603 |
|
|
|
604 |
class_type = "gene"
|
605 |
padding: Union[bool, str, PaddingStrategy] = True
|
606 |
max_length: Optional[int] = None
|
607 |
pad_to_multiple_of: Optional[int] = None
|
608 |
label_pad_token_id: int = -100
|
609 |
-
|
610 |
def __init__(self, *args, **kwargs) -> None:
|
611 |
-
self.token_dictionary = kwargs.pop("token_dictionary")
|
612 |
super().__init__(
|
613 |
-
tokenizer=
|
614 |
-
token_dictionary=self.token_dictionary
|
615 |
-
),
|
616 |
padding=self.padding,
|
617 |
max_length=self.max_length,
|
618 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
619 |
label_pad_token_id=self.label_pad_token_id,
|
620 |
-
*args,
|
621 |
-
**kwargs,
|
622 |
-
)
|
623 |
|
624 |
def _prepare_batch(self, features):
|
625 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
626 |
-
labels = (
|
627 |
-
[feature[label_name] for feature in features]
|
628 |
-
if label_name in features[0].keys()
|
629 |
-
else None
|
630 |
-
)
|
631 |
batch = self.tokenizer.pad(
|
632 |
features,
|
633 |
class_type=self.class_type,
|
@@ -637,31 +574,29 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
637 |
return_tensors="pt",
|
638 |
)
|
639 |
return batch
|
640 |
-
|
641 |
def __call__(self, features):
|
642 |
batch = self._prepare_batch(features)
|
643 |
|
644 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
645 |
return batch
|
646 |
|
647 |
-
|
648 |
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
|
|
649 |
class_type = "cell"
|
650 |
|
651 |
def _prepare_batch(self, features):
|
|
|
652 |
batch = super()._prepare_batch(features)
|
653 |
-
|
654 |
# Special handling for labels.
|
655 |
# Ensure that tensor is created with the correct type
|
656 |
# (it should be automatically the case, but let's make sure of it.)
|
657 |
first = features[0]
|
658 |
if "label" in first and first["label"] is not None:
|
659 |
-
label = (
|
660 |
-
first["label"].item()
|
661 |
-
if isinstance(first["label"], torch.Tensor)
|
662 |
-
else first["label"]
|
663 |
-
)
|
664 |
dtype = torch.long if isinstance(label, int) else torch.float
|
665 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
666 |
-
|
667 |
return batch
|
|
|
1 |
"""
|
2 |
Geneformer collator for gene and cell classification.
|
3 |
+
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
import warnings
|
9 |
from enum import Enum
|
10 |
from typing import Dict, List, Optional, Union
|
11 |
|
|
|
|
|
12 |
from transformers import (
|
|
|
13 |
DataCollatorForTokenClassification,
|
14 |
SpecialTokensMixin,
|
15 |
+
BatchEncoding,
|
16 |
)
|
17 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
|
20 |
+
from .pretrainer import token_dictionary
|
21 |
+
|
22 |
EncodedInput = List[int]
|
23 |
logger = logging.get_logger(__name__)
|
24 |
VERY_LARGE_INTEGER = int(
|
|
|
30 |
|
31 |
# precollator functions
|
32 |
|
|
|
33 |
class ExplicitEnum(Enum):
|
34 |
"""
|
35 |
Enum with more explicit error message for missing values.
|
|
|
42 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
43 |
)
|
44 |
|
|
|
45 |
class TruncationStrategy(ExplicitEnum):
|
46 |
"""
|
47 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
54 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
|
56 |
|
57 |
+
|
58 |
class PaddingStrategy(ExplicitEnum):
|
59 |
"""
|
60 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
|
|
66 |
DO_NOT_PAD = "do_not_pad"
|
67 |
|
68 |
|
69 |
+
|
70 |
class TensorType(ExplicitEnum):
|
71 |
"""
|
72 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
78 |
NUMPY = "np"
|
79 |
JAX = "jax"
|
80 |
|
81 |
+
|
82 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
83 |
+
mask_token = "<mask>"
|
84 |
+
mask_token_id = token_dictionary.get("<mask>")
|
85 |
+
pad_token = "<pad>"
|
86 |
+
pad_token_id = token_dictionary.get("<pad>")
|
87 |
+
padding_side = "right"
|
88 |
+
all_special_ids = [
|
89 |
+
token_dictionary.get("<mask>"),
|
90 |
+
token_dictionary.get("<pad>")
|
91 |
+
]
|
92 |
+
model_input_names = ["input_ids"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
def _get_padding_truncation_strategies(
|
95 |
+
self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
):
|
97 |
"""
|
98 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
|
|
105 |
# If you only set max_length, it activates truncation for max_length
|
106 |
if max_length is not None and padding is False and truncation is False:
|
107 |
if verbose:
|
108 |
+
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
|
|
|
|
|
109 |
logger.warning(
|
110 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
111 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
|
|
133 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
134 |
elif padding is not False:
|
135 |
if padding is True:
|
136 |
+
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
|
|
|
|
|
137 |
elif not isinstance(padding, PaddingStrategy):
|
138 |
padding_strategy = PaddingStrategy(padding)
|
139 |
elif isinstance(padding, PaddingStrategy):
|
|
|
173 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
174 |
if self.model_max_length > LARGE_INTEGER:
|
175 |
if verbose:
|
176 |
+
if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
|
|
|
|
|
177 |
logger.warning(
|
178 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
179 |
"Default to no padding."
|
|
|
186 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
187 |
if self.model_max_length > LARGE_INTEGER:
|
188 |
if verbose:
|
189 |
+
if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
|
|
|
|
|
190 |
logger.warning(
|
191 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
192 |
"Default to no truncation."
|
193 |
)
|
194 |
+
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
|
|
|
|
|
195 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
196 |
else:
|
197 |
max_length = self.model_max_length
|
198 |
|
199 |
# Test if we have a padding token
|
200 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
|
|
|
|
|
201 |
raise ValueError(
|
202 |
"Asking to pad but the tokenizer does not have a padding token. "
|
203 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
|
|
228 |
Dict[str, List[EncodedInput]],
|
229 |
List[Dict[str, EncodedInput]],
|
230 |
],
|
231 |
+
class_type, # options: "gene" or "cell"
|
232 |
padding: Union[bool, str, PaddingStrategy] = True,
|
233 |
max_length: Optional[int] = None,
|
234 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
239 |
"""
|
240 |
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
241 |
in the batch.
|
242 |
+
|
243 |
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
244 |
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
245 |
+
|
246 |
.. note::
|
247 |
+
|
248 |
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
249 |
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
250 |
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
251 |
+
|
252 |
Args:
|
253 |
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
254 |
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
255 |
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
256 |
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
257 |
well as in a PyTorch Dataloader collate function.
|
258 |
+
|
259 |
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
260 |
see the note above for the return type.
|
261 |
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
262 |
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
263 |
index) among:
|
264 |
+
|
265 |
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
266 |
single sequence if provided).
|
267 |
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
|
|
272 |
Maximum length of the returned list and optionally padding length (see above).
|
273 |
pad_to_multiple_of (:obj:`int`, `optional`):
|
274 |
If set will pad the sequence to a multiple of the provided value.
|
275 |
+
|
276 |
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
277 |
>= 7.5 (Volta).
|
278 |
return_attention_mask (:obj:`bool`, `optional`):
|
279 |
Whether to return the attention mask. If left to the default, will return the attention mask according
|
280 |
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
281 |
+
|
282 |
`What are attention masks? <../glossary.html#attention-mask>`__
|
283 |
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
284 |
If set, will return tensors instead of list of python integers. Acceptable values are:
|
285 |
+
|
286 |
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
287 |
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
288 |
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
|
|
291 |
"""
|
292 |
# If we have a list of dicts, let's convert it in a dict of lists
|
293 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
294 |
+
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
295 |
+
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
298 |
if self.model_input_names[0] not in encoded_inputs:
|
|
|
386 |
def _pad(
|
387 |
self,
|
388 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
389 |
+
class_type, # options: "gene" or "cell"
|
390 |
max_length: Optional[int] = None,
|
391 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
392 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
394 |
) -> dict:
|
395 |
"""
|
396 |
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
397 |
+
|
398 |
Args:
|
399 |
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
400 |
max_length: maximum length of the returned list and optionally padding length (see below).
|
401 |
Will truncate by taking into account the special tokens.
|
402 |
padding_strategy: PaddingStrategy to use for padding.
|
403 |
+
|
404 |
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
405 |
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
406 |
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
407 |
The tokenizer padding sides are defined in self.padding_side:
|
408 |
+
|
409 |
- 'left': pads on the left of the sequences
|
410 |
- 'right': pads on the right of the sequences
|
411 |
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
|
|
422 |
if padding_strategy == PaddingStrategy.LONGEST:
|
423 |
max_length = len(required_input)
|
424 |
|
425 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
|
|
|
|
|
|
|
|
426 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
427 |
|
428 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
|
|
|
|
|
|
429 |
|
430 |
if needs_to_be_padded:
|
431 |
difference = max_length - len(required_input)
|
432 |
if self.padding_side == "right":
|
433 |
if return_attention_mask:
|
434 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
|
|
|
|
435 |
if "token_type_ids" in encoded_inputs:
|
436 |
encoded_inputs["token_type_ids"] = (
|
437 |
+
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
|
|
438 |
)
|
439 |
if "special_tokens_mask" in encoded_inputs:
|
440 |
+
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
441 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
|
|
|
|
|
|
|
|
442 |
if class_type == "gene":
|
443 |
+
encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
|
|
|
|
|
444 |
elif self.padding_side == "left":
|
445 |
if return_attention_mask:
|
446 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
|
|
|
|
447 |
if "token_type_ids" in encoded_inputs:
|
448 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
449 |
+
"token_type_ids"
|
450 |
+
]
|
451 |
if "special_tokens_mask" in encoded_inputs:
|
452 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
453 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
|
|
|
|
|
|
|
|
454 |
if class_type == "gene":
|
455 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
|
|
|
|
|
456 |
else:
|
457 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
458 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
459 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
460 |
+
|
461 |
return encoded_inputs
|
462 |
|
463 |
def get_special_tokens_mask(
|
464 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
|
|
|
|
|
|
465 |
) -> List[int]:
|
466 |
"""
|
467 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
|
|
485 |
|
486 |
all_special_ids = self.all_special_ids # cache the property
|
487 |
|
488 |
+
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
|
|
|
|
|
489 |
|
490 |
return special_tokens_mask
|
491 |
|
492 |
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
|
|
|
|
493 |
"""
|
494 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
495 |
vocabulary.
|
|
|
513 |
if token is None:
|
514 |
return None
|
515 |
|
516 |
+
return token_dictionary.get(token)
|
517 |
|
518 |
def __len__(self):
|
519 |
+
return len(token_dictionary)
|
520 |
|
521 |
|
522 |
# collator functions
|
523 |
|
|
|
524 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
525 |
"""
|
526 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
|
|
546 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
547 |
"""
|
548 |
|
549 |
+
tokenizer = PrecollatorForGeneAndCellClassification()
|
550 |
class_type = "gene"
|
551 |
padding: Union[bool, str, PaddingStrategy] = True
|
552 |
max_length: Optional[int] = None
|
553 |
pad_to_multiple_of: Optional[int] = None
|
554 |
label_pad_token_id: int = -100
|
555 |
+
|
556 |
def __init__(self, *args, **kwargs) -> None:
|
|
|
557 |
super().__init__(
|
558 |
+
tokenizer=self.tokenizer,
|
|
|
|
|
559 |
padding=self.padding,
|
560 |
max_length=self.max_length,
|
561 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
562 |
label_pad_token_id=self.label_pad_token_id,
|
563 |
+
*args, **kwargs)
|
|
|
|
|
564 |
|
565 |
def _prepare_batch(self, features):
|
566 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
567 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
|
|
|
|
|
|
|
|
568 |
batch = self.tokenizer.pad(
|
569 |
features,
|
570 |
class_type=self.class_type,
|
|
|
574 |
return_tensors="pt",
|
575 |
)
|
576 |
return batch
|
577 |
+
|
578 |
def __call__(self, features):
|
579 |
batch = self._prepare_batch(features)
|
580 |
|
581 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
582 |
return batch
|
583 |
|
584 |
+
|
585 |
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
586 |
+
|
587 |
class_type = "cell"
|
588 |
|
589 |
def _prepare_batch(self, features):
|
590 |
+
|
591 |
batch = super()._prepare_batch(features)
|
592 |
+
|
593 |
# Special handling for labels.
|
594 |
# Ensure that tensor is created with the correct type
|
595 |
# (it should be automatically the case, but let's make sure of it.)
|
596 |
first = features[0]
|
597 |
if "label" in first and first["label"] is not None:
|
598 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
|
|
|
|
|
|
|
|
599 |
dtype = torch.long if isinstance(label, int) else torch.float
|
600 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
601 |
+
|
602 |
return batch
|
geneformer/emb_extractor.py
CHANGED
@@ -24,8 +24,8 @@ import torch
|
|
24 |
from tdigest import TDigest
|
25 |
from tqdm.auto import trange
|
26 |
|
27 |
-
from . import TOKEN_DICTIONARY_FILE
|
28 |
from . import perturber_utils as pu
|
|
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
@@ -38,8 +38,6 @@ def get_embs(
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
41 |
-
token_gene_dict,
|
42 |
-
special_token=False,
|
43 |
summary_stat=None,
|
44 |
silent=False,
|
45 |
):
|
@@ -49,8 +47,10 @@ def get_embs(
|
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
52 |
-
#
|
53 |
-
|
|
|
|
|
54 |
if emb_mode == "cell":
|
55 |
# initiate tdigests for # of emb dims
|
56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
@@ -67,27 +67,6 @@ def get_embs(
|
|
67 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
}
|
69 |
|
70 |
-
# Check if CLS and EOS token is present in the token dictionary
|
71 |
-
cls_present = any("<cls>" in value for value in token_gene_dict.values())
|
72 |
-
eos_present = any("<eos>" in value for value in token_gene_dict.values())
|
73 |
-
if emb_mode == "cls":
|
74 |
-
assert cls_present, "<cls> token missing in token dictionary"
|
75 |
-
# Check to make sure that the first token of the filtered input data is cls token
|
76 |
-
gene_token_dict = {v: k for k, v in token_gene_dict.items()}
|
77 |
-
cls_token_id = gene_token_dict["<cls>"]
|
78 |
-
assert (
|
79 |
-
filtered_input_data["input_ids"][0][0] == cls_token_id
|
80 |
-
), "First token is not <cls> token value"
|
81 |
-
elif emb_mode == "cell":
|
82 |
-
if cls_present:
|
83 |
-
logger.warning(
|
84 |
-
"CLS token present in token dictionary, excluding from average."
|
85 |
-
)
|
86 |
-
if eos_present:
|
87 |
-
logger.warning(
|
88 |
-
"EOS token present in token dictionary, excluding from average."
|
89 |
-
)
|
90 |
-
|
91 |
overall_max_len = 0
|
92 |
|
93 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
@@ -113,14 +92,7 @@ def get_embs(
|
|
113 |
embs_i = outputs.hidden_states[layer_to_quant]
|
114 |
|
115 |
if emb_mode == "cell":
|
116 |
-
|
117 |
-
non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
|
118 |
-
if eos_present:
|
119 |
-
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
120 |
-
else:
|
121 |
-
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
|
122 |
-
else:
|
123 |
-
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
124 |
if summary_stat is None:
|
125 |
embs_list.append(mean_embs)
|
126 |
elif summary_stat is not None:
|
@@ -149,12 +121,6 @@ def get_embs(
|
|
149 |
accumulate_tdigests(
|
150 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
151 |
)
|
152 |
-
del embs_h
|
153 |
-
del dict_h
|
154 |
-
elif emb_mode == "cls":
|
155 |
-
cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
|
156 |
-
embs_list.append(cls_embs)
|
157 |
-
del cls_embs
|
158 |
|
159 |
overall_max_len = max(overall_max_len, max_len)
|
160 |
del outputs
|
@@ -165,7 +131,7 @@ def get_embs(
|
|
165 |
torch.cuda.empty_cache()
|
166 |
|
167 |
if summary_stat is None:
|
168 |
-
if
|
169 |
embs_stack = torch.cat(embs_list, dim=0)
|
170 |
elif emb_mode == "gene":
|
171 |
embs_stack = pu.pad_tensor_list(
|
@@ -243,6 +209,14 @@ def tdigest_median(embs_tdigests, emb_dims):
|
|
243 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
244 |
|
245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
247 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
248 |
if emb_labels is not None:
|
@@ -278,7 +252,7 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
|
278 |
return embs_df
|
279 |
|
280 |
|
281 |
-
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict
|
282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
@@ -288,27 +262,15 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
|
288 |
obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
|
289 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
290 |
sc.tl.pca(adata, svd_solver="arpack")
|
291 |
-
sc.pp.neighbors(adata
|
292 |
-
sc.tl.umap(adata
|
293 |
sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
|
294 |
sns.set_style("white")
|
295 |
-
default_kwargs_dict = {"size": 200}
|
296 |
if kwargs_dict is not None:
|
297 |
default_kwargs_dict.update(kwargs_dict)
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
with plt.rc_context():
|
302 |
-
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
303 |
-
ax.legend(
|
304 |
-
markerscale=2,
|
305 |
-
frameon=False,
|
306 |
-
loc="center left",
|
307 |
-
bbox_to_anchor=(1, 0.5),
|
308 |
-
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
309 |
-
)
|
310 |
-
plt.show()
|
311 |
-
plt.savefig(output_file, bbox_inches="tight")
|
312 |
|
313 |
|
314 |
def gen_heatmap_class_colors(labels, df):
|
@@ -384,8 +346,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
384 |
bbox_to_anchor=(0.5, 1),
|
385 |
facecolor="white",
|
386 |
)
|
387 |
-
|
388 |
-
logger.info(f"Output file: {output_file}")
|
389 |
plt.savefig(output_file, bbox_inches="tight")
|
390 |
|
391 |
|
@@ -393,7 +354,7 @@ class EmbExtractor:
|
|
393 |
valid_option_dict = {
|
394 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
395 |
"num_classes": {int},
|
396 |
-
"emb_mode": {"
|
397 |
"cell_emb_style": {"mean_pool"},
|
398 |
"gene_emb_style": {"mean_pool"},
|
399 |
"filter_data": {None, dict},
|
@@ -402,7 +363,6 @@ class EmbExtractor:
|
|
402 |
"emb_label": {None, list},
|
403 |
"labels_to_plot": {None, list},
|
404 |
"forward_batch_size": {int},
|
405 |
-
"token_dictionary_file": {None, str},
|
406 |
"nproc": {int},
|
407 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
408 |
}
|
@@ -411,7 +371,7 @@ class EmbExtractor:
|
|
411 |
self,
|
412 |
model_type="Pretrained",
|
413 |
num_classes=0,
|
414 |
-
emb_mode="
|
415 |
cell_emb_style="mean_pool",
|
416 |
gene_emb_style="mean_pool",
|
417 |
filter_data=None,
|
@@ -422,7 +382,7 @@ class EmbExtractor:
|
|
422 |
forward_batch_size=100,
|
423 |
nproc=4,
|
424 |
summary_stat=None,
|
425 |
-
token_dictionary_file=
|
426 |
):
|
427 |
"""
|
428 |
Initialize embedding extractor.
|
@@ -434,11 +394,10 @@ class EmbExtractor:
|
|
434 |
num_classes : int
|
435 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
436 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
437 |
-
emb_mode : {"
|
438 |
-
| Whether to output
|
439 |
-
|
440 |
-
|
441 |
-
| Method for summarizing cell embeddings if not using CLS token.
|
442 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
443 |
gene_emb_style : "mean_pool"
|
444 |
| Method for summarizing gene embeddings.
|
@@ -473,7 +432,6 @@ class EmbExtractor:
|
|
473 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
474 |
| Non-exact is slower but more memory-efficient.
|
475 |
token_dictionary_file : Path
|
476 |
-
| Default is the Geneformer token dictionary
|
477 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
478 |
|
479 |
**Examples:**
|
@@ -486,6 +444,7 @@ class EmbExtractor:
|
|
486 |
... emb_mode="cell",
|
487 |
... filter_data={"cell_type":["cardiomyocyte"]},
|
488 |
... max_ncells=1000,
|
|
|
489 |
... emb_layer=-1,
|
490 |
... emb_label=["disease", "cell_type"],
|
491 |
... labels_to_plot=["disease", "cell_type"])
|
@@ -502,7 +461,6 @@ class EmbExtractor:
|
|
502 |
self.emb_layer = emb_layer
|
503 |
self.emb_label = emb_label
|
504 |
self.labels_to_plot = labels_to_plot
|
505 |
-
self.token_dictionary_file = token_dictionary_file
|
506 |
self.forward_batch_size = forward_batch_size
|
507 |
self.nproc = nproc
|
508 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
@@ -515,8 +473,6 @@ class EmbExtractor:
|
|
515 |
self.validate_options()
|
516 |
|
517 |
# load token dictionary (Ensembl IDs:token)
|
518 |
-
if self.token_dictionary_file is None:
|
519 |
-
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
520 |
with open(token_dictionary_file, "rb") as f:
|
521 |
self.gene_token_dict = pickle.load(f)
|
522 |
|
@@ -532,7 +488,7 @@ class EmbExtractor:
|
|
532 |
continue
|
533 |
valid_type = False
|
534 |
for option in valid_options:
|
535 |
-
if (option in [int, list, dict, bool
|
536 |
attr_value, option
|
537 |
):
|
538 |
valid_type = True
|
@@ -606,14 +562,13 @@ class EmbExtractor:
|
|
606 |
)
|
607 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
608 |
embs = get_embs(
|
609 |
-
model
|
610 |
-
|
611 |
-
|
612 |
-
layer_to_quant
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
summary_stat=self.summary_stat,
|
617 |
)
|
618 |
|
619 |
if self.emb_mode == "cell":
|
@@ -627,8 +582,6 @@ class EmbExtractor:
|
|
627 |
elif self.summary_stat is not None:
|
628 |
embs_df = pd.DataFrame(embs).T
|
629 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
630 |
-
elif self.emb_mode == "cls":
|
631 |
-
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
632 |
|
633 |
# save embeddings to output_path
|
634 |
if cell_state is None:
|
@@ -637,17 +590,13 @@ class EmbExtractor:
|
|
637 |
|
638 |
if self.exact_summary_stat == "exact_mean":
|
639 |
embs = embs.mean(dim=0)
|
640 |
-
emb_dims = pu.get_model_emb_dims(model)
|
641 |
embs_df = pd.DataFrame(
|
642 |
-
embs_df[0
|
643 |
-
columns=[self.exact_summary_stat],
|
644 |
).T
|
645 |
elif self.exact_summary_stat == "exact_median":
|
646 |
embs = torch.median(embs, dim=0)[0]
|
647 |
-
emb_dims = pu.get_model_emb_dims(model)
|
648 |
embs_df = pd.DataFrame(
|
649 |
-
embs_df[0
|
650 |
-
columns=[self.exact_summary_stat],
|
651 |
).T
|
652 |
|
653 |
if cell_state is not None:
|
@@ -800,15 +749,15 @@ class EmbExtractor:
|
|
800 |
logger.error("Plotting UMAP requires 'labels_to_plot'. ")
|
801 |
raise
|
802 |
|
803 |
-
if max_ncells_to_plot
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
|
813 |
if self.emb_label is None:
|
814 |
label_len = 0
|
@@ -830,11 +779,11 @@ class EmbExtractor:
|
|
830 |
f"not present in provided embeddings dataframe."
|
831 |
)
|
832 |
continue
|
833 |
-
output_prefix_label = output_prefix + f"_umap_{label}"
|
834 |
output_file = (
|
835 |
Path(output_directory) / output_prefix_label
|
836 |
).with_suffix(".pdf")
|
837 |
-
plot_umap(embs, emb_dims, label,
|
838 |
|
839 |
if plot_style == "heatmap":
|
840 |
for label in self.labels_to_plot:
|
|
|
24 |
from tdigest import TDigest
|
25 |
from tqdm.auto import trange
|
26 |
|
|
|
27 |
from . import perturber_utils as pu
|
28 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
|
|
|
|
41 |
summary_stat=None,
|
42 |
silent=False,
|
43 |
):
|
|
|
47 |
if summary_stat is None:
|
48 |
embs_list = []
|
49 |
elif summary_stat is not None:
|
50 |
+
# test embedding extraction for example cell and extract # emb dims
|
51 |
+
example = filtered_input_data.select([i for i in range(1)])
|
52 |
+
example.set_format(type="torch")
|
53 |
+
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
54 |
if emb_mode == "cell":
|
55 |
# initiate tdigests for # of emb dims
|
56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
|
67 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
overall_max_len = 0
|
71 |
|
72 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
|
|
92 |
embs_i = outputs.hidden_states[layer_to_quant]
|
93 |
|
94 |
if emb_mode == "cell":
|
95 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if summary_stat is None:
|
97 |
embs_list.append(mean_embs)
|
98 |
elif summary_stat is not None:
|
|
|
121 |
accumulate_tdigests(
|
122 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
123 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
overall_max_len = max(overall_max_len, max_len)
|
126 |
del outputs
|
|
|
131 |
torch.cuda.empty_cache()
|
132 |
|
133 |
if summary_stat is None:
|
134 |
+
if emb_mode == "cell":
|
135 |
embs_stack = torch.cat(embs_list, dim=0)
|
136 |
elif emb_mode == "gene":
|
137 |
embs_stack = pu.pad_tensor_list(
|
|
|
209 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
210 |
|
211 |
|
212 |
+
def test_emb(model, example, layer_to_quant):
|
213 |
+
with torch.no_grad():
|
214 |
+
outputs = model(input_ids=example.to("cuda"))
|
215 |
+
|
216 |
+
embs_test = outputs.hidden_states[layer_to_quant]
|
217 |
+
return embs_test.size()[2]
|
218 |
+
|
219 |
+
|
220 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
221 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
222 |
if emb_labels is not None:
|
|
|
252 |
return embs_df
|
253 |
|
254 |
|
255 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
256 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
257 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
258 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
|
|
262 |
obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
|
263 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
264 |
sc.tl.pca(adata, svd_solver="arpack")
|
265 |
+
sc.pp.neighbors(adata)
|
266 |
+
sc.tl.umap(adata)
|
267 |
sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
|
268 |
sns.set_style("white")
|
269 |
+
default_kwargs_dict = {"palette": "Set2", "size": 200}
|
270 |
if kwargs_dict is not None:
|
271 |
default_kwargs_dict.update(kwargs_dict)
|
272 |
|
273 |
+
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
|
276 |
def gen_heatmap_class_colors(labels, df):
|
|
|
346 |
bbox_to_anchor=(0.5, 1),
|
347 |
facecolor="white",
|
348 |
)
|
349 |
+
|
|
|
350 |
plt.savefig(output_file, bbox_inches="tight")
|
351 |
|
352 |
|
|
|
354 |
valid_option_dict = {
|
355 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
356 |
"num_classes": {int},
|
357 |
+
"emb_mode": {"cell", "gene"},
|
358 |
"cell_emb_style": {"mean_pool"},
|
359 |
"gene_emb_style": {"mean_pool"},
|
360 |
"filter_data": {None, dict},
|
|
|
363 |
"emb_label": {None, list},
|
364 |
"labels_to_plot": {None, list},
|
365 |
"forward_batch_size": {int},
|
|
|
366 |
"nproc": {int},
|
367 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
368 |
}
|
|
|
371 |
self,
|
372 |
model_type="Pretrained",
|
373 |
num_classes=0,
|
374 |
+
emb_mode="cell",
|
375 |
cell_emb_style="mean_pool",
|
376 |
gene_emb_style="mean_pool",
|
377 |
filter_data=None,
|
|
|
382 |
forward_batch_size=100,
|
383 |
nproc=4,
|
384 |
summary_stat=None,
|
385 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
386 |
):
|
387 |
"""
|
388 |
Initialize embedding extractor.
|
|
|
394 |
num_classes : int
|
395 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
396 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
397 |
+
emb_mode : {"cell", "gene"}
|
398 |
+
| Whether to output cell or gene embeddings.
|
399 |
+
cell_emb_style : "mean_pool"
|
400 |
+
| Method for summarizing cell embeddings.
|
|
|
401 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
402 |
gene_emb_style : "mean_pool"
|
403 |
| Method for summarizing gene embeddings.
|
|
|
432 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
433 |
| Non-exact is slower but more memory-efficient.
|
434 |
token_dictionary_file : Path
|
|
|
435 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
436 |
|
437 |
**Examples:**
|
|
|
444 |
... emb_mode="cell",
|
445 |
... filter_data={"cell_type":["cardiomyocyte"]},
|
446 |
... max_ncells=1000,
|
447 |
+
... max_ncells_to_plot=1000,
|
448 |
... emb_layer=-1,
|
449 |
... emb_label=["disease", "cell_type"],
|
450 |
... labels_to_plot=["disease", "cell_type"])
|
|
|
461 |
self.emb_layer = emb_layer
|
462 |
self.emb_label = emb_label
|
463 |
self.labels_to_plot = labels_to_plot
|
|
|
464 |
self.forward_batch_size = forward_batch_size
|
465 |
self.nproc = nproc
|
466 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
|
|
473 |
self.validate_options()
|
474 |
|
475 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
476 |
with open(token_dictionary_file, "rb") as f:
|
477 |
self.gene_token_dict = pickle.load(f)
|
478 |
|
|
|
488 |
continue
|
489 |
valid_type = False
|
490 |
for option in valid_options:
|
491 |
+
if (option in [int, list, dict, bool]) and isinstance(
|
492 |
attr_value, option
|
493 |
):
|
494 |
valid_type = True
|
|
|
562 |
)
|
563 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
564 |
embs = get_embs(
|
565 |
+
model,
|
566 |
+
downsampled_data,
|
567 |
+
self.emb_mode,
|
568 |
+
layer_to_quant,
|
569 |
+
self.pad_token_id,
|
570 |
+
self.forward_batch_size,
|
571 |
+
self.summary_stat,
|
|
|
572 |
)
|
573 |
|
574 |
if self.emb_mode == "cell":
|
|
|
582 |
elif self.summary_stat is not None:
|
583 |
embs_df = pd.DataFrame(embs).T
|
584 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
|
585 |
|
586 |
# save embeddings to output_path
|
587 |
if cell_state is None:
|
|
|
590 |
|
591 |
if self.exact_summary_stat == "exact_mean":
|
592 |
embs = embs.mean(dim=0)
|
|
|
593 |
embs_df = pd.DataFrame(
|
594 |
+
embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
|
|
|
595 |
).T
|
596 |
elif self.exact_summary_stat == "exact_median":
|
597 |
embs = torch.median(embs, dim=0)[0]
|
|
|
598 |
embs_df = pd.DataFrame(
|
599 |
+
embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
|
|
|
600 |
).T
|
601 |
|
602 |
if cell_state is not None:
|
|
|
749 |
logger.error("Plotting UMAP requires 'labels_to_plot'. ")
|
750 |
raise
|
751 |
|
752 |
+
if max_ncells_to_plot > self.max_ncells:
|
753 |
+
max_ncells_to_plot = self.max_ncells
|
754 |
+
logger.warning(
|
755 |
+
"max_ncells_to_plot must be <= max_ncells. "
|
756 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
757 |
+
)
|
758 |
+
|
759 |
+
if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
|
760 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
761 |
|
762 |
if self.emb_label is None:
|
763 |
label_len = 0
|
|
|
779 |
f"not present in provided embeddings dataframe."
|
780 |
)
|
781 |
continue
|
782 |
+
output_prefix_label = "_" + output_prefix + f"_umap_{label}"
|
783 |
output_file = (
|
784 |
Path(output_directory) / output_prefix_label
|
785 |
).with_suffix(".pdf")
|
786 |
+
plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
|
787 |
|
788 |
if plot_style == "heatmap":
|
789 |
for label in self.labels_to_plot:
|
geneformer/ensembl_mapping_dict_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
|
3 |
-
size 3957652
|
|
|
|
|
|
|
|
geneformer/evaluation_utils.py
CHANGED
@@ -20,20 +20,20 @@ from sklearn.metrics import (
|
|
20 |
)
|
21 |
from tqdm.auto import trange
|
22 |
|
23 |
-
from . import TOKEN_DICTIONARY_FILE
|
24 |
from .emb_extractor import make_colorbar
|
|
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def preprocess_classifier_batch(cell_batch, max_len, label_name):
|
30 |
if max_len is None:
|
31 |
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
32 |
|
33 |
-
# load token dictionary (Ensembl IDs:token)
|
34 |
-
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
35 |
-
gene_token_dict = pickle.load(f)
|
36 |
-
|
37 |
def pad_label_example(example):
|
38 |
example[label_name] = np.pad(
|
39 |
example[label_name],
|
|
|
20 |
)
|
21 |
from tqdm.auto import trange
|
22 |
|
|
|
23 |
from .emb_extractor import make_colorbar
|
24 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
28 |
+
# load token dictionary (Ensembl IDs:token)
|
29 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
30 |
+
gene_token_dict = pickle.load(f)
|
31 |
+
|
32 |
|
33 |
def preprocess_classifier_batch(cell_batch, max_len, label_name):
|
34 |
if max_len is None:
|
35 |
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
36 |
|
|
|
|
|
|
|
|
|
37 |
def pad_label_example(example):
|
38 |
example[label_name] = np.pad(
|
39 |
example[label_name],
|
geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:eac0fb0b3007267871b6305ac0003ceba19d4f28d85686cb9067ecf142787869
|
3 |
-
size 584125
|
|
|
|
|
|
|
|
geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
|
3 |
-
size 940965
|
|
|
|
|
|
|
|
geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
|
3 |
-
size 788424
|
|
|
|
|
|
|
|
geneformer/gene_median_dictionary.pkl
ADDED
Binary file (941 kB). View file
|
|
geneformer/gene_median_dictionary_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5
|
3 |
-
size 1512661
|
|
|
|
|
|
|
|
geneformer/{gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl → gene_name_id_dict.pkl}
RENAMED
File without changes
|
geneformer/gene_name_id_dict_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
|
3 |
-
size 2038812
|
|
|
|
|
|
|
|
geneformer/in_silico_perturber.py
CHANGED
@@ -38,17 +38,19 @@ import logging
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
|
|
41 |
|
|
|
42 |
import torch
|
43 |
-
from datasets import Dataset
|
44 |
-
from multiprocess import set_start_method
|
45 |
from tqdm.auto import trange
|
46 |
|
47 |
-
from . import TOKEN_DICTIONARY_FILE
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
|
|
|
|
|
|
50 |
|
51 |
-
disable_progress_bars()
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
@@ -60,9 +62,9 @@ class InSilicoPerturber:
|
|
60 |
"genes_to_perturb": {"all", list},
|
61 |
"combos": {0, 1},
|
62 |
"anchor_gene": {None, str},
|
63 |
-
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"
|
64 |
"num_classes": {int},
|
65 |
-
"emb_mode": {"
|
66 |
"cell_emb_style": {"mean_pool"},
|
67 |
"filter_data": {None, dict},
|
68 |
"cell_states_to_model": {None, dict},
|
@@ -70,7 +72,6 @@ class InSilicoPerturber:
|
|
70 |
"max_ncells": {None, int},
|
71 |
"cell_inds_to_perturb": {"all", dict},
|
72 |
"emb_layer": {-1, 0},
|
73 |
-
"token_dictionary_file": {None, str},
|
74 |
"forward_batch_size": {int},
|
75 |
"nproc": {int},
|
76 |
}
|
@@ -94,8 +95,7 @@ class InSilicoPerturber:
|
|
94 |
emb_layer=-1,
|
95 |
forward_batch_size=100,
|
96 |
nproc=4,
|
97 |
-
token_dictionary_file=
|
98 |
-
clear_mem_ncells=1000,
|
99 |
):
|
100 |
"""
|
101 |
Initialize in silico perturber.
|
@@ -130,16 +130,16 @@ class InSilicoPerturber:
|
|
130 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
131 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
132 |
| anchor gene will be perturbed in combination with each other gene.
|
133 |
-
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"
|
134 |
-
| Whether model is the pretrained Geneformer or a fine-tuned gene
|
135 |
num_classes : int
|
136 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
137 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
138 |
-
emb_mode : {"
|
139 |
-
| Whether to output impact of perturbation on
|
140 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
141 |
cell_emb_style : "mean_pool"
|
142 |
-
| Method for summarizing cell embeddings
|
143 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
144 |
filter_data : None, dict
|
145 |
| Default is to use all input data for in silico perturbation study.
|
@@ -184,13 +184,7 @@ class InSilicoPerturber:
|
|
184 |
| Number of CPU processes to use.
|
185 |
token_dictionary_file : Path
|
186 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
187 |
-
clear_mem_ncells : int
|
188 |
-
| Clear memory every n cells.
|
189 |
"""
|
190 |
-
try:
|
191 |
-
set_start_method("spawn")
|
192 |
-
except RuntimeError:
|
193 |
-
pass
|
194 |
|
195 |
self.perturb_type = perturb_type
|
196 |
self.perturb_rank_shift = perturb_rank_shift
|
@@ -222,32 +216,14 @@ class InSilicoPerturber:
|
|
222 |
self.emb_layer = emb_layer
|
223 |
self.forward_batch_size = forward_batch_size
|
224 |
self.nproc = nproc
|
225 |
-
self.token_dictionary_file = token_dictionary_file
|
226 |
-
self.clear_mem_ncells = clear_mem_ncells
|
227 |
|
228 |
self.validate_options()
|
229 |
|
230 |
# load token dictionary (Ensembl IDs:token)
|
231 |
-
if self.token_dictionary_file is None:
|
232 |
-
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
233 |
with open(token_dictionary_file, "rb") as f:
|
234 |
self.gene_token_dict = pickle.load(f)
|
235 |
-
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
236 |
|
237 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
238 |
-
self.cls_token_id = self.gene_token_dict.get("<cls>")
|
239 |
-
self.eos_token_id = self.gene_token_dict.get("<eos>")
|
240 |
-
|
241 |
-
# Identify if special token is present in the token dictionary
|
242 |
-
if (self.cls_token_id is not None) and (self.eos_token_id is not None):
|
243 |
-
self.special_token = True
|
244 |
-
else:
|
245 |
-
if "cls" in self.emb_mode:
|
246 |
-
logger.error(
|
247 |
-
f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
|
248 |
-
)
|
249 |
-
raise
|
250 |
-
self.special_token = False
|
251 |
|
252 |
if self.anchor_gene is None:
|
253 |
self.anchor_token = None
|
@@ -305,7 +281,7 @@ class InSilicoPerturber:
|
|
305 |
continue
|
306 |
valid_type = False
|
307 |
for option in valid_options:
|
308 |
-
if (option in [bool, int, list, dict
|
309 |
attr_value, option
|
310 |
):
|
311 |
valid_type = True
|
@@ -451,45 +427,16 @@ class InSilicoPerturber:
|
|
451 |
filtered_input_data = pu.load_and_filter(
|
452 |
self.filter_data, self.nproc, input_data_file
|
453 |
)
|
454 |
-
|
455 |
-
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
-
if self.special_token:
|
457 |
-
if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
|
458 |
-
"cls" not in self.emb_mode
|
459 |
-
):
|
460 |
-
logger.error(
|
461 |
-
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
462 |
-
)
|
463 |
-
raise
|
464 |
-
if "cls" in self.emb_mode:
|
465 |
-
if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
|
466 |
-
filtered_input_data["input_ids"][0][-1] != self.eos_token_id
|
467 |
-
):
|
468 |
-
logger.error(
|
469 |
-
"Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
|
470 |
-
)
|
471 |
-
raise
|
472 |
-
|
473 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
474 |
|
475 |
if self.perturb_group is True:
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
)
|
480 |
-
else:
|
481 |
-
self.isp_perturb_set(
|
482 |
-
model, filtered_input_data, layer_to_quant, output_path_prefix
|
483 |
-
)
|
484 |
else:
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
)
|
489 |
-
else:
|
490 |
-
self.isp_perturb_all(
|
491 |
-
model, filtered_input_data, layer_to_quant, output_path_prefix
|
492 |
-
)
|
493 |
|
494 |
def apply_additional_filters(self, filtered_input_data):
|
495 |
# additional filtering of input data dependent on isp mode
|
@@ -552,9 +499,7 @@ class InSilicoPerturber:
|
|
552 |
if self.perturb_type == "delete":
|
553 |
example = pu.delete_indices(example)
|
554 |
elif self.perturb_type == "overexpress":
|
555 |
-
example = pu.overexpress_tokens(
|
556 |
-
example, self.max_len, self.special_token
|
557 |
-
)
|
558 |
example["n_overflow"] = pu.calc_n_overflow(
|
559 |
self.max_len,
|
560 |
example["length"],
|
@@ -575,7 +520,6 @@ class InSilicoPerturber:
|
|
575 |
perturbed_data = filtered_input_data.map(
|
576 |
make_group_perturbation_batch, num_proc=self.nproc
|
577 |
)
|
578 |
-
|
579 |
if self.perturb_type == "overexpress":
|
580 |
filtered_input_data = filtered_input_data.add_column(
|
581 |
"n_overflow", perturbed_data["n_overflow"]
|
@@ -608,7 +552,6 @@ class InSilicoPerturber:
|
|
608 |
layer_to_quant,
|
609 |
self.pad_token_id,
|
610 |
self.forward_batch_size,
|
611 |
-
token_gene_dict=self.token_gene_dict,
|
612 |
summary_stat=None,
|
613 |
silent=True,
|
614 |
)
|
@@ -628,7 +571,6 @@ class InSilicoPerturber:
|
|
628 |
layer_to_quant,
|
629 |
self.pad_token_id,
|
630 |
self.forward_batch_size,
|
631 |
-
token_gene_dict=self.token_gene_dict,
|
632 |
summary_stat=None,
|
633 |
silent=True,
|
634 |
)
|
@@ -728,6 +670,8 @@ class InSilicoPerturber:
|
|
728 |
cos_sims_dict = self.update_perturbation_dictionary(
|
729 |
cos_sims_dict,
|
730 |
cos_sims_data,
|
|
|
|
|
731 |
gene_list,
|
732 |
)
|
733 |
else:
|
@@ -736,6 +680,8 @@ class InSilicoPerturber:
|
|
736 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
737 |
cos_sims_dict[state],
|
738 |
cos_sims_data[state],
|
|
|
|
|
739 |
gene_list,
|
740 |
)
|
741 |
del minibatch
|
@@ -757,264 +703,6 @@ class InSilicoPerturber:
|
|
757 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
758 |
)
|
759 |
|
760 |
-
def isp_perturb_set_special(
|
761 |
-
self,
|
762 |
-
model,
|
763 |
-
filtered_input_data: Dataset,
|
764 |
-
layer_to_quant: int,
|
765 |
-
output_path_prefix: str,
|
766 |
-
):
|
767 |
-
def make_group_perturbation_batch(example):
|
768 |
-
example_input_ids = example["input_ids"]
|
769 |
-
example["tokens_to_perturb"] = self.tokens_to_perturb
|
770 |
-
indices_to_perturb = [
|
771 |
-
example_input_ids.index(token) if token in example_input_ids else None
|
772 |
-
for token in self.tokens_to_perturb
|
773 |
-
]
|
774 |
-
indices_to_perturb = [
|
775 |
-
item for item in indices_to_perturb if item is not None
|
776 |
-
]
|
777 |
-
if len(indices_to_perturb) > 0:
|
778 |
-
example["perturb_index"] = indices_to_perturb
|
779 |
-
else:
|
780 |
-
# -100 indicates tokens to overexpress are not present in rank value encoding
|
781 |
-
example["perturb_index"] = [-100]
|
782 |
-
if self.perturb_type == "delete":
|
783 |
-
example = pu.delete_indices(example)
|
784 |
-
elif self.perturb_type == "overexpress":
|
785 |
-
example = pu.overexpress_tokens(
|
786 |
-
example, self.max_len, self.special_token
|
787 |
-
)
|
788 |
-
example["n_overflow"] = pu.calc_n_overflow(
|
789 |
-
self.max_len,
|
790 |
-
example["length"],
|
791 |
-
self.tokens_to_perturb,
|
792 |
-
indices_to_perturb,
|
793 |
-
)
|
794 |
-
return example
|
795 |
-
|
796 |
-
total_batch_length = len(filtered_input_data)
|
797 |
-
if self.cell_states_to_model is None:
|
798 |
-
cos_sims_dict = defaultdict(list)
|
799 |
-
else:
|
800 |
-
cos_sims_dict = {
|
801 |
-
state: defaultdict(list)
|
802 |
-
for state in pu.get_possible_states(self.cell_states_to_model)
|
803 |
-
}
|
804 |
-
|
805 |
-
perturbed_data = filtered_input_data.map(
|
806 |
-
make_group_perturbation_batch, num_proc=self.nproc
|
807 |
-
)
|
808 |
-
|
809 |
-
if self.perturb_type == "overexpress":
|
810 |
-
filtered_input_data = filtered_input_data.add_column(
|
811 |
-
"n_overflow", perturbed_data["n_overflow"]
|
812 |
-
)
|
813 |
-
filtered_input_data = filtered_input_data.map(
|
814 |
-
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
815 |
-
)
|
816 |
-
|
817 |
-
if self.emb_mode == "cls_and_gene":
|
818 |
-
stored_gene_embs_dict = defaultdict(list)
|
819 |
-
|
820 |
-
# iterate through batches
|
821 |
-
for i in trange(0, total_batch_length, self.forward_batch_size):
|
822 |
-
max_range = min(i + self.forward_batch_size, total_batch_length)
|
823 |
-
inds_select = [i for i in range(i, max_range)]
|
824 |
-
|
825 |
-
minibatch = filtered_input_data.select(inds_select)
|
826 |
-
perturbation_batch = perturbed_data.select(inds_select)
|
827 |
-
|
828 |
-
##### CLS Embedding Mode #####
|
829 |
-
if self.emb_mode == "cls":
|
830 |
-
indices_to_perturb = perturbation_batch["perturb_index"]
|
831 |
-
|
832 |
-
original_cls_emb = get_embs(
|
833 |
-
model,
|
834 |
-
minibatch,
|
835 |
-
"cls",
|
836 |
-
layer_to_quant,
|
837 |
-
self.pad_token_id,
|
838 |
-
self.forward_batch_size,
|
839 |
-
token_gene_dict=self.token_gene_dict,
|
840 |
-
summary_stat=None,
|
841 |
-
silent=True,
|
842 |
-
)
|
843 |
-
|
844 |
-
perturbation_cls_emb = get_embs(
|
845 |
-
model,
|
846 |
-
perturbation_batch,
|
847 |
-
"cls",
|
848 |
-
layer_to_quant,
|
849 |
-
self.pad_token_id,
|
850 |
-
self.forward_batch_size,
|
851 |
-
token_gene_dict=self.token_gene_dict,
|
852 |
-
summary_stat=None,
|
853 |
-
silent=True,
|
854 |
-
)
|
855 |
-
|
856 |
-
# Calculate the cosine similarities
|
857 |
-
cls_cos_sims = pu.quant_cos_sims(
|
858 |
-
perturbation_cls_emb,
|
859 |
-
original_cls_emb,
|
860 |
-
self.cell_states_to_model,
|
861 |
-
self.state_embs_dict,
|
862 |
-
emb_mode="cell",
|
863 |
-
)
|
864 |
-
|
865 |
-
# Update perturbation dictionary
|
866 |
-
if self.cell_states_to_model is None:
|
867 |
-
cos_sims_dict = self.update_perturbation_dictionary(
|
868 |
-
cos_sims_dict,
|
869 |
-
cls_cos_sims,
|
870 |
-
gene_list=None,
|
871 |
-
)
|
872 |
-
else:
|
873 |
-
for state in cos_sims_dict.keys():
|
874 |
-
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
875 |
-
cos_sims_dict[state],
|
876 |
-
cls_cos_sims[state],
|
877 |
-
gene_list=None,
|
878 |
-
)
|
879 |
-
|
880 |
-
##### CLS and Gene Embedding Mode #####
|
881 |
-
elif self.emb_mode == "cls_and_gene":
|
882 |
-
full_original_emb = get_embs(
|
883 |
-
model,
|
884 |
-
minibatch,
|
885 |
-
"gene",
|
886 |
-
layer_to_quant,
|
887 |
-
self.pad_token_id,
|
888 |
-
self.forward_batch_size,
|
889 |
-
self.token_gene_dict,
|
890 |
-
summary_stat=None,
|
891 |
-
silent=True,
|
892 |
-
)
|
893 |
-
indices_to_perturb = perturbation_batch["perturb_index"]
|
894 |
-
# remove indices that were perturbed
|
895 |
-
original_emb = pu.remove_perturbed_indices_set(
|
896 |
-
full_original_emb,
|
897 |
-
self.perturb_type,
|
898 |
-
indices_to_perturb,
|
899 |
-
self.tokens_to_perturb,
|
900 |
-
minibatch["length"],
|
901 |
-
)
|
902 |
-
full_perturbation_emb = get_embs(
|
903 |
-
model,
|
904 |
-
perturbation_batch,
|
905 |
-
"gene",
|
906 |
-
layer_to_quant,
|
907 |
-
self.pad_token_id,
|
908 |
-
self.forward_batch_size,
|
909 |
-
self.token_gene_dict,
|
910 |
-
summary_stat=None,
|
911 |
-
silent=True,
|
912 |
-
)
|
913 |
-
|
914 |
-
# remove special tokens and padding
|
915 |
-
original_emb = original_emb[:, 1:-1, :]
|
916 |
-
if self.perturb_type == "overexpress":
|
917 |
-
perturbation_emb = full_perturbation_emb[
|
918 |
-
:, 1 + len(self.tokens_to_perturb) : -1, :
|
919 |
-
]
|
920 |
-
elif self.perturb_type == "delete":
|
921 |
-
perturbation_emb = full_perturbation_emb[
|
922 |
-
:, 1 : max(perturbation_batch["length"]) - 1, :
|
923 |
-
]
|
924 |
-
|
925 |
-
n_perturbation_genes = perturbation_emb.size()[1]
|
926 |
-
|
927 |
-
gene_cos_sims = pu.quant_cos_sims(
|
928 |
-
perturbation_emb,
|
929 |
-
original_emb,
|
930 |
-
self.cell_states_to_model,
|
931 |
-
self.state_embs_dict,
|
932 |
-
emb_mode="gene",
|
933 |
-
)
|
934 |
-
|
935 |
-
# get cls emb
|
936 |
-
original_cls_emb = full_original_emb[:, 0, :]
|
937 |
-
perturbation_cls_emb = full_perturbation_emb[:, 0, :]
|
938 |
-
|
939 |
-
cls_cos_sims = pu.quant_cos_sims(
|
940 |
-
perturbation_cls_emb,
|
941 |
-
original_cls_emb,
|
942 |
-
self.cell_states_to_model,
|
943 |
-
self.state_embs_dict,
|
944 |
-
emb_mode="cell",
|
945 |
-
)
|
946 |
-
|
947 |
-
# get cosine similarities in gene embeddings
|
948 |
-
# since getting gene embeddings, need gene names
|
949 |
-
|
950 |
-
gene_list = minibatch["input_ids"]
|
951 |
-
# need to truncate gene_list
|
952 |
-
genes_to_exclude = self.tokens_to_perturb + [
|
953 |
-
self.cls_token_id,
|
954 |
-
self.eos_token_id,
|
955 |
-
]
|
956 |
-
gene_list = [
|
957 |
-
[g for g in genes if g not in genes_to_exclude][
|
958 |
-
:n_perturbation_genes
|
959 |
-
]
|
960 |
-
for genes in gene_list
|
961 |
-
]
|
962 |
-
|
963 |
-
for cell_i, genes in enumerate(gene_list):
|
964 |
-
for gene_j, affected_gene in enumerate(genes):
|
965 |
-
if len(self.genes_to_perturb) > 1:
|
966 |
-
tokens_to_perturb = tuple(self.tokens_to_perturb)
|
967 |
-
else:
|
968 |
-
tokens_to_perturb = self.tokens_to_perturb[0]
|
969 |
-
|
970 |
-
# fill in the gene cosine similarities
|
971 |
-
try:
|
972 |
-
stored_gene_embs_dict[
|
973 |
-
(tokens_to_perturb, affected_gene)
|
974 |
-
].append(gene_cos_sims[cell_i, gene_j].item())
|
975 |
-
except KeyError:
|
976 |
-
stored_gene_embs_dict[
|
977 |
-
(tokens_to_perturb, affected_gene)
|
978 |
-
] = gene_cos_sims[cell_i, gene_j].item()
|
979 |
-
|
980 |
-
if self.cell_states_to_model is None:
|
981 |
-
cos_sims_dict = self.update_perturbation_dictionary(
|
982 |
-
cos_sims_dict,
|
983 |
-
cls_cos_sims,
|
984 |
-
gene_list=None,
|
985 |
-
)
|
986 |
-
else:
|
987 |
-
for state in cos_sims_dict.keys():
|
988 |
-
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
989 |
-
cos_sims_dict[state],
|
990 |
-
cls_cos_sims[state],
|
991 |
-
gene_list=None,
|
992 |
-
)
|
993 |
-
del full_original_emb
|
994 |
-
del original_emb
|
995 |
-
del full_perturbation_emb
|
996 |
-
del perturbation_emb
|
997 |
-
del gene_cos_sims
|
998 |
-
|
999 |
-
del original_cls_emb
|
1000 |
-
del perturbation_cls_emb
|
1001 |
-
del cls_cos_sims
|
1002 |
-
del minibatch
|
1003 |
-
del perturbation_batch
|
1004 |
-
|
1005 |
-
torch.cuda.empty_cache()
|
1006 |
-
|
1007 |
-
pu.write_perturbation_dictionary(
|
1008 |
-
cos_sims_dict,
|
1009 |
-
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
1010 |
-
)
|
1011 |
-
|
1012 |
-
if self.emb_mode == "cls_and_gene":
|
1013 |
-
pu.write_perturbation_dictionary(
|
1014 |
-
stored_gene_embs_dict,
|
1015 |
-
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
1016 |
-
)
|
1017 |
-
|
1018 |
def isp_perturb_all(
|
1019 |
self,
|
1020 |
model,
|
@@ -1033,10 +721,8 @@ class InSilicoPerturber:
|
|
1033 |
|
1034 |
if self.emb_mode == "cell_and_gene":
|
1035 |
stored_gene_embs_dict = defaultdict(list)
|
1036 |
-
|
1037 |
-
|
1038 |
-
for h in trange(len(filtered_input_data)):
|
1039 |
-
example_cell = filtered_input_data.select([h])
|
1040 |
full_original_emb = get_embs(
|
1041 |
model,
|
1042 |
example_cell,
|
@@ -1044,30 +730,16 @@ class InSilicoPerturber:
|
|
1044 |
layer_to_quant,
|
1045 |
self.pad_token_id,
|
1046 |
self.forward_batch_size,
|
1047 |
-
self.token_gene_dict,
|
1048 |
summary_stat=None,
|
1049 |
silent=True,
|
1050 |
)
|
1051 |
|
1052 |
-
if self.cell_states_to_model is not None:
|
1053 |
-
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1054 |
-
full_original_emb, "mean_pool"
|
1055 |
-
)
|
1056 |
-
|
1057 |
# gene_list is used to assign cos sims back to genes
|
1058 |
-
gene_list = example_cell["input_ids"][0][:]
|
1059 |
# need to remove the anchor gene
|
|
|
1060 |
if self.anchor_token is not None:
|
1061 |
for token in self.anchor_token:
|
1062 |
gene_list.remove(token)
|
1063 |
-
# index 0 is not overexpressed so remove
|
1064 |
-
if self.perturb_type == "overexpress":
|
1065 |
-
gene_list = gene_list[num_inds_perturbed:]
|
1066 |
-
# remove perturbed index for gene list dict
|
1067 |
-
perturbed_gene_dict = {
|
1068 |
-
gene: gene_list[:i] + gene_list[i + 1 :]
|
1069 |
-
for i, gene in enumerate(gene_list)
|
1070 |
-
}
|
1071 |
|
1072 |
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
1073 |
example_cell,
|
@@ -1078,459 +750,147 @@ class InSilicoPerturber:
|
|
1078 |
self.nproc,
|
1079 |
)
|
1080 |
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1092 |
-
gene_list_mini = gene_list[
|
1093 |
-
i:ispall_max_range
|
1094 |
-
] # only perturbed genes from this minibatch
|
1095 |
-
|
1096 |
-
full_perturbation_emb = get_embs(
|
1097 |
-
model,
|
1098 |
-
perturbation_minibatch,
|
1099 |
-
"gene",
|
1100 |
-
layer_to_quant,
|
1101 |
-
self.pad_token_id,
|
1102 |
-
self.forward_batch_size,
|
1103 |
-
self.token_gene_dict,
|
1104 |
-
summary_stat=None,
|
1105 |
-
silent=True,
|
1106 |
-
)
|
1107 |
-
|
1108 |
-
del perturbation_minibatch
|
1109 |
-
|
1110 |
-
# need to remove overexpressed gene to quantify cosine shifts
|
1111 |
-
if self.perturb_type == "overexpress":
|
1112 |
-
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
1113 |
-
|
1114 |
-
elif self.perturb_type == "delete":
|
1115 |
-
perturbation_emb = full_perturbation_emb
|
1116 |
-
|
1117 |
-
if (
|
1118 |
-
self.cell_states_to_model is None
|
1119 |
-
or self.emb_mode == "cell_and_gene"
|
1120 |
-
):
|
1121 |
-
original_emb_minibatch = pu.make_comparison_batch(
|
1122 |
-
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1123 |
-
)
|
1124 |
-
gene_cos_sims = pu.quant_cos_sims(
|
1125 |
-
perturbation_emb,
|
1126 |
-
original_emb_minibatch,
|
1127 |
-
self.cell_states_to_model,
|
1128 |
-
self.state_embs_dict,
|
1129 |
-
emb_mode="gene",
|
1130 |
-
)
|
1131 |
-
del original_emb_minibatch
|
1132 |
-
|
1133 |
-
if self.cell_states_to_model is not None:
|
1134 |
-
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1135 |
-
full_perturbation_emb, "mean_pool"
|
1136 |
-
)
|
1137 |
-
|
1138 |
-
cell_cos_sims = pu.quant_cos_sims(
|
1139 |
-
perturbation_cell_emb,
|
1140 |
-
original_cell_emb,
|
1141 |
-
self.cell_states_to_model,
|
1142 |
-
self.state_embs_dict,
|
1143 |
-
emb_mode="cell",
|
1144 |
-
)
|
1145 |
-
del perturbation_cell_emb
|
1146 |
-
|
1147 |
-
if self.emb_mode == "cell_and_gene":
|
1148 |
-
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1149 |
-
for gene_j, affected_gene in enumerate(
|
1150 |
-
perturbed_gene_dict[perturbed_gene]
|
1151 |
-
):
|
1152 |
-
try:
|
1153 |
-
stored_gene_embs_dict[
|
1154 |
-
(perturbed_gene, affected_gene)
|
1155 |
-
].append(gene_cos_sims[perturbation_i, gene_j].item())
|
1156 |
-
except KeyError:
|
1157 |
-
stored_gene_embs_dict[
|
1158 |
-
(perturbed_gene, affected_gene)
|
1159 |
-
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1160 |
|
1161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1162 |
|
1163 |
-
|
1164 |
-
|
1165 |
-
cos_sims_dict = self.update_perturbation_dictionary(
|
1166 |
-
cos_sims_dict,
|
1167 |
-
cos_sims_data,
|
1168 |
-
gene_list_mini,
|
1169 |
-
)
|
1170 |
-
else:
|
1171 |
-
cos_sims_data = cell_cos_sims
|
1172 |
-
for state in cos_sims_dict.keys():
|
1173 |
-
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1174 |
-
cos_sims_dict[state],
|
1175 |
-
cos_sims_data[state],
|
1176 |
-
gene_list_mini,
|
1177 |
-
)
|
1178 |
-
|
1179 |
-
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1180 |
-
if i % self.clear_mem_ncells / 10 == 0:
|
1181 |
-
pu.write_perturbation_dictionary(
|
1182 |
-
cos_sims_dict,
|
1183 |
-
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1184 |
-
)
|
1185 |
-
if self.emb_mode == "cell_and_gene":
|
1186 |
-
pu.write_perturbation_dictionary(
|
1187 |
-
stored_gene_embs_dict,
|
1188 |
-
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1189 |
-
)
|
1190 |
-
|
1191 |
-
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1192 |
-
if i % self.clear_mem_ncells == 0:
|
1193 |
-
pickle_batch += 1
|
1194 |
-
if self.cell_states_to_model is None:
|
1195 |
-
cos_sims_dict = defaultdict(list)
|
1196 |
-
else:
|
1197 |
-
cos_sims_dict = {
|
1198 |
-
state: defaultdict(list)
|
1199 |
-
for state in pu.get_possible_states(
|
1200 |
-
self.cell_states_to_model
|
1201 |
-
)
|
1202 |
-
}
|
1203 |
-
|
1204 |
-
if self.emb_mode == "cell_and_gene":
|
1205 |
-
stored_gene_embs_dict = defaultdict(list)
|
1206 |
-
|
1207 |
-
torch.cuda.empty_cache()
|
1208 |
|
1209 |
-
pu.
|
1210 |
-
|
1211 |
-
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1212 |
)
|
1213 |
|
1214 |
-
if self.emb_mode == "cell_and_gene":
|
1215 |
-
pu.
|
1216 |
-
|
1217 |
-
|
|
|
|
|
|
|
1218 |
)
|
1219 |
-
|
1220 |
-
pickle_batch = -1
|
1221 |
-
if self.cell_states_to_model is None:
|
1222 |
-
cos_sims_dict = defaultdict(list)
|
1223 |
-
else:
|
1224 |
-
cos_sims_dict = {
|
1225 |
-
state: defaultdict(list)
|
1226 |
-
for state in pu.get_possible_states(self.cell_states_to_model)
|
1227 |
-
}
|
1228 |
-
|
1229 |
-
if self.emb_mode == "cell_and_gene":
|
1230 |
-
stored_gene_embs_dict = defaultdict(list)
|
1231 |
-
|
1232 |
-
# clear memory between cells
|
1233 |
-
del perturbation_batch
|
1234 |
-
del full_original_emb
|
1235 |
if self.cell_states_to_model is not None:
|
1236 |
-
|
1237 |
-
|
1238 |
-
|
1239 |
-
def isp_perturb_all_special(
|
1240 |
-
self,
|
1241 |
-
model,
|
1242 |
-
filtered_input_data: Dataset,
|
1243 |
-
layer_to_quant: int,
|
1244 |
-
output_path_prefix: str,
|
1245 |
-
):
|
1246 |
-
pickle_batch = -1
|
1247 |
-
if self.cell_states_to_model is None:
|
1248 |
-
cos_sims_dict = defaultdict(list)
|
1249 |
-
else:
|
1250 |
-
cos_sims_dict = {
|
1251 |
-
state: defaultdict(list)
|
1252 |
-
for state in pu.get_possible_states(self.cell_states_to_model)
|
1253 |
-
}
|
1254 |
-
|
1255 |
-
if self.emb_mode == "cls_and_gene":
|
1256 |
-
stored_gene_embs_dict = defaultdict(list)
|
1257 |
-
|
1258 |
-
num_inds_perturbed = 1 + self.combos
|
1259 |
-
for h in trange(len(filtered_input_data)):
|
1260 |
-
example_cell = filtered_input_data.select([h])
|
1261 |
-
|
1262 |
-
# get original example cell cls and/or gene embs for comparison
|
1263 |
-
if self.emb_mode == "cls":
|
1264 |
-
original_cls_emb = get_embs(
|
1265 |
-
model,
|
1266 |
-
example_cell,
|
1267 |
-
"cls",
|
1268 |
-
layer_to_quant,
|
1269 |
-
self.pad_token_id,
|
1270 |
-
self.forward_batch_size,
|
1271 |
-
self.token_gene_dict,
|
1272 |
-
summary_stat=None,
|
1273 |
-
silent=True,
|
1274 |
)
|
1275 |
-
|
1276 |
-
|
1277 |
-
model,
|
1278 |
-
example_cell,
|
1279 |
-
"gene",
|
1280 |
-
layer_to_quant,
|
1281 |
-
self.pad_token_id,
|
1282 |
-
self.forward_batch_size,
|
1283 |
-
self.token_gene_dict,
|
1284 |
-
summary_stat=None,
|
1285 |
-
silent=True,
|
1286 |
)
|
1287 |
-
original_cls_emb = full_original_emb[:, 0, :].clone().detach()
|
1288 |
-
|
1289 |
-
# gene_list is used to assign cos sims back to genes
|
1290 |
-
gene_list = example_cell["input_ids"][0][:]
|
1291 |
|
1292 |
-
|
1293 |
-
|
1294 |
-
|
1295 |
-
|
1296 |
-
|
1297 |
-
|
1298 |
-
gene_list.remove(token)
|
1299 |
-
# index 0 is not overexpressed so remove
|
1300 |
-
if self.perturb_type == "overexpress":
|
1301 |
-
gene_list = gene_list[num_inds_perturbed:]
|
1302 |
-
# remove perturbed index for gene list dict
|
1303 |
-
perturbed_gene_dict = {
|
1304 |
-
gene: gene_list[:i] + gene_list[i + 1 :]
|
1305 |
-
for i, gene in enumerate(gene_list)
|
1306 |
-
}
|
1307 |
-
|
1308 |
-
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
|
1309 |
-
example_cell,
|
1310 |
-
self.perturb_type,
|
1311 |
-
self.tokens_to_perturb,
|
1312 |
-
self.anchor_token,
|
1313 |
-
self.combos,
|
1314 |
-
self.nproc,
|
1315 |
-
)
|
1316 |
-
|
1317 |
-
ispall_total_batch_length = len(perturbation_batch)
|
1318 |
-
for i in trange(
|
1319 |
-
0, ispall_total_batch_length, self.forward_batch_size, leave=False
|
1320 |
-
):
|
1321 |
-
ispall_max_range = min(
|
1322 |
-
i + self.forward_batch_size, ispall_total_batch_length
|
1323 |
-
)
|
1324 |
-
perturbation_minibatch = perturbation_batch.select(
|
1325 |
-
[i for i in range(i, ispall_max_range)]
|
1326 |
)
|
1327 |
-
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1328 |
-
gene_list_mini = gene_list[
|
1329 |
-
i:ispall_max_range
|
1330 |
-
] # only perturbed genes from this minibatch
|
1331 |
-
|
1332 |
-
##### CLS Embedding Mode #####
|
1333 |
-
if self.emb_mode == "cls":
|
1334 |
-
# Extract cls embeddings from perturbed cells
|
1335 |
-
perturbation_cls_emb = get_embs(
|
1336 |
-
model,
|
1337 |
-
perturbation_minibatch,
|
1338 |
-
"cls",
|
1339 |
-
layer_to_quant,
|
1340 |
-
self.pad_token_id,
|
1341 |
-
self.forward_batch_size,
|
1342 |
-
self.token_gene_dict,
|
1343 |
-
summary_stat=None,
|
1344 |
-
silent=True,
|
1345 |
-
)
|
1346 |
|
1347 |
-
|
1348 |
-
|
1349 |
-
|
1350 |
-
|
1351 |
-
|
1352 |
-
|
1353 |
-
emb_mode="cell",
|
1354 |
-
)
|
1355 |
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
|
1366 |
-
|
1367 |
-
|
1368 |
-
)
|
1369 |
-
|
1370 |
-
del perturbation_minibatch
|
1371 |
-
del perturbation_cls_emb
|
1372 |
-
del cls_cos_sims
|
1373 |
-
|
1374 |
-
##### CLS and Gene Embedding Mode #####
|
1375 |
-
elif self.emb_mode == "cls_and_gene":
|
1376 |
-
full_perturbation_emb = get_embs(
|
1377 |
-
model,
|
1378 |
-
perturbation_minibatch,
|
1379 |
-
"gene",
|
1380 |
-
layer_to_quant,
|
1381 |
-
self.pad_token_id,
|
1382 |
-
self.forward_batch_size,
|
1383 |
-
self.token_gene_dict,
|
1384 |
-
summary_stat=None,
|
1385 |
-
silent=True,
|
1386 |
-
)
|
1387 |
|
1388 |
-
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
1395 |
-
|
1396 |
-
|
1397 |
-
|
1398 |
-
|
1399 |
-
|
1400 |
-
|
1401 |
-
|
|
|
|
|
|
|
|
|
1402 |
)
|
1403 |
|
1404 |
-
|
1405 |
-
|
1406 |
-
|
1407 |
-
|
1408 |
-
|
1409 |
-
|
1410 |
-
|
1411 |
-
|
1412 |
-
|
|
|
1413 |
)
|
1414 |
|
1415 |
-
|
1416 |
-
|
1417 |
-
|
1418 |
-
|
1419 |
-
|
1420 |
-
|
1421 |
-
|
1422 |
-
|
1423 |
-
|
1424 |
-
|
1425 |
-
(perturbed_gene, affected_gene)
|
1426 |
-
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1427 |
|
1428 |
-
|
1429 |
-
|
1430 |
-
full_perturbation_emb[:, 0, :].clone().detach()
|
1431 |
-
)
|
1432 |
|
1433 |
-
|
1434 |
-
perturbation_cls_emb,
|
1435 |
-
original_cls_emb,
|
1436 |
-
self.cell_states_to_model,
|
1437 |
-
self.state_embs_dict,
|
1438 |
-
emb_mode="cell",
|
1439 |
-
)
|
1440 |
|
1441 |
-
|
1442 |
-
|
1443 |
-
|
1444 |
-
cls_cos_sims,
|
1445 |
-
gene_list_mini,
|
1446 |
-
)
|
1447 |
-
else:
|
1448 |
-
for state in cos_sims_dict.keys():
|
1449 |
-
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1450 |
-
cos_sims_dict[state],
|
1451 |
-
cls_cos_sims[state],
|
1452 |
-
gene_list_mini,
|
1453 |
-
)
|
1454 |
-
|
1455 |
-
del perturbation_minibatch
|
1456 |
-
del original_emb_minibatch
|
1457 |
-
del full_perturbation_emb
|
1458 |
-
del perturbation_emb
|
1459 |
-
del perturbation_cls_emb
|
1460 |
-
del cls_cos_sims
|
1461 |
-
del gene_cos_sims
|
1462 |
-
|
1463 |
-
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1464 |
-
if i % max(1, self.clear_mem_ncells / 10) == 0:
|
1465 |
-
pu.write_perturbation_dictionary(
|
1466 |
-
cos_sims_dict,
|
1467 |
-
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1468 |
-
)
|
1469 |
-
if self.emb_mode == "cls_and_gene":
|
1470 |
-
pu.write_perturbation_dictionary(
|
1471 |
-
stored_gene_embs_dict,
|
1472 |
-
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1473 |
-
)
|
1474 |
-
|
1475 |
-
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1476 |
-
if i % self.clear_mem_ncells == 0:
|
1477 |
-
pickle_batch += 1
|
1478 |
-
if self.cell_states_to_model is None:
|
1479 |
-
cos_sims_dict = defaultdict(list)
|
1480 |
-
else:
|
1481 |
-
cos_sims_dict = {
|
1482 |
-
state: defaultdict(list)
|
1483 |
-
for state in pu.get_possible_states(
|
1484 |
-
self.cell_states_to_model
|
1485 |
-
)
|
1486 |
-
}
|
1487 |
-
|
1488 |
-
if self.emb_mode == "cls_and_gene":
|
1489 |
-
stored_gene_embs_dict = defaultdict(list)
|
1490 |
-
|
1491 |
-
torch.cuda.empty_cache()
|
1492 |
|
|
|
1493 |
pu.write_perturbation_dictionary(
|
1494 |
-
|
1495 |
-
f"{output_path_prefix}
|
1496 |
)
|
1497 |
|
1498 |
-
if self.emb_mode == "cls_and_gene":
|
1499 |
-
pu.write_perturbation_dictionary(
|
1500 |
-
stored_gene_embs_dict,
|
1501 |
-
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1502 |
-
)
|
1503 |
-
|
1504 |
-
pickle_batch = -1
|
1505 |
-
if self.cell_states_to_model is None:
|
1506 |
-
cos_sims_dict = defaultdict(list)
|
1507 |
-
else:
|
1508 |
-
cos_sims_dict = {
|
1509 |
-
state: defaultdict(list)
|
1510 |
-
for state in pu.get_possible_states(self.cell_states_to_model)
|
1511 |
-
}
|
1512 |
-
|
1513 |
-
if self.emb_mode == "cls_and_gene":
|
1514 |
-
stored_gene_embs_dict = defaultdict(list)
|
1515 |
-
|
1516 |
-
# clear memory between cells
|
1517 |
-
del perturbation_batch
|
1518 |
-
del original_cls_emb
|
1519 |
-
if self.emb_mode == "cls_and_gene":
|
1520 |
-
del full_original_emb
|
1521 |
-
torch.cuda.empty_cache()
|
1522 |
-
|
1523 |
def update_perturbation_dictionary(
|
1524 |
self,
|
1525 |
cos_sims_dict: defaultdict,
|
1526 |
cos_sims_data: torch.Tensor,
|
|
|
|
|
1527 |
gene_list=None,
|
1528 |
):
|
1529 |
if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
|
1530 |
logger.error(
|
1531 |
f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
|
1532 |
-
{cos_sims_data.shape[0]
|
1533 |
-
{len(gene_list)
|
1534 |
)
|
1535 |
raise
|
1536 |
|
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
41 |
+
from typing import List
|
42 |
|
43 |
+
import seaborn as sns
|
44 |
import torch
|
45 |
+
from datasets import Dataset
|
|
|
46 |
from tqdm.auto import trange
|
47 |
|
|
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
50 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
51 |
+
|
52 |
+
sns.set()
|
53 |
|
|
|
54 |
|
55 |
logger = logging.getLogger(__name__)
|
56 |
|
|
|
62 |
"genes_to_perturb": {"all", list},
|
63 |
"combos": {0, 1},
|
64 |
"anchor_gene": {None, str},
|
65 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
66 |
"num_classes": {int},
|
67 |
+
"emb_mode": {"cell", "cell_and_gene"},
|
68 |
"cell_emb_style": {"mean_pool"},
|
69 |
"filter_data": {None, dict},
|
70 |
"cell_states_to_model": {None, dict},
|
|
|
72 |
"max_ncells": {None, int},
|
73 |
"cell_inds_to_perturb": {"all", dict},
|
74 |
"emb_layer": {-1, 0},
|
|
|
75 |
"forward_batch_size": {int},
|
76 |
"nproc": {int},
|
77 |
}
|
|
|
95 |
emb_layer=-1,
|
96 |
forward_batch_size=100,
|
97 |
nproc=4,
|
98 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
|
|
99 |
):
|
100 |
"""
|
101 |
Initialize in silico perturber.
|
|
|
130 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
131 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
132 |
| anchor gene will be perturbed in combination with each other gene.
|
133 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
|
134 |
+
| Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
135 |
num_classes : int
|
136 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
137 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
138 |
+
emb_mode : {"cell", "cell_and_gene"}
|
139 |
+
| Whether to output impact of perturbation on cell and/or gene embeddings.
|
140 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
141 |
cell_emb_style : "mean_pool"
|
142 |
+
| Method for summarizing cell embeddings.
|
143 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
144 |
filter_data : None, dict
|
145 |
| Default is to use all input data for in silico perturbation study.
|
|
|
184 |
| Number of CPU processes to use.
|
185 |
token_dictionary_file : Path
|
186 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
|
|
|
|
187 |
"""
|
|
|
|
|
|
|
|
|
188 |
|
189 |
self.perturb_type = perturb_type
|
190 |
self.perturb_rank_shift = perturb_rank_shift
|
|
|
216 |
self.emb_layer = emb_layer
|
217 |
self.forward_batch_size = forward_batch_size
|
218 |
self.nproc = nproc
|
|
|
|
|
219 |
|
220 |
self.validate_options()
|
221 |
|
222 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
223 |
with open(token_dictionary_file, "rb") as f:
|
224 |
self.gene_token_dict = pickle.load(f)
|
|
|
225 |
|
226 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
if self.anchor_gene is None:
|
229 |
self.anchor_token = None
|
|
|
281 |
continue
|
282 |
valid_type = False
|
283 |
for option in valid_options:
|
284 |
+
if (option in [bool, int, list, dict]) and isinstance(
|
285 |
attr_value, option
|
286 |
):
|
287 |
valid_type = True
|
|
|
427 |
filtered_input_data = pu.load_and_filter(
|
428 |
self.filter_data, self.nproc, input_data_file
|
429 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
431 |
|
432 |
if self.perturb_group is True:
|
433 |
+
self.isp_perturb_set(
|
434 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
435 |
+
)
|
|
|
|
|
|
|
|
|
|
|
436 |
else:
|
437 |
+
self.isp_perturb_all(
|
438 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
439 |
+
)
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
def apply_additional_filters(self, filtered_input_data):
|
442 |
# additional filtering of input data dependent on isp mode
|
|
|
499 |
if self.perturb_type == "delete":
|
500 |
example = pu.delete_indices(example)
|
501 |
elif self.perturb_type == "overexpress":
|
502 |
+
example = pu.overexpress_tokens(example, self.max_len)
|
|
|
|
|
503 |
example["n_overflow"] = pu.calc_n_overflow(
|
504 |
self.max_len,
|
505 |
example["length"],
|
|
|
520 |
perturbed_data = filtered_input_data.map(
|
521 |
make_group_perturbation_batch, num_proc=self.nproc
|
522 |
)
|
|
|
523 |
if self.perturb_type == "overexpress":
|
524 |
filtered_input_data = filtered_input_data.add_column(
|
525 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
552 |
layer_to_quant,
|
553 |
self.pad_token_id,
|
554 |
self.forward_batch_size,
|
|
|
555 |
summary_stat=None,
|
556 |
silent=True,
|
557 |
)
|
|
|
571 |
layer_to_quant,
|
572 |
self.pad_token_id,
|
573 |
self.forward_batch_size,
|
|
|
574 |
summary_stat=None,
|
575 |
silent=True,
|
576 |
)
|
|
|
670 |
cos_sims_dict = self.update_perturbation_dictionary(
|
671 |
cos_sims_dict,
|
672 |
cos_sims_data,
|
673 |
+
filtered_input_data,
|
674 |
+
indices_to_perturb,
|
675 |
gene_list,
|
676 |
)
|
677 |
else:
|
|
|
680 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
681 |
cos_sims_dict[state],
|
682 |
cos_sims_data[state],
|
683 |
+
filtered_input_data,
|
684 |
+
indices_to_perturb,
|
685 |
gene_list,
|
686 |
)
|
687 |
del minibatch
|
|
|
703 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
704 |
)
|
705 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
706 |
def isp_perturb_all(
|
707 |
self,
|
708 |
model,
|
|
|
721 |
|
722 |
if self.emb_mode == "cell_and_gene":
|
723 |
stored_gene_embs_dict = defaultdict(list)
|
724 |
+
for i in trange(len(filtered_input_data)):
|
725 |
+
example_cell = filtered_input_data.select([i])
|
|
|
|
|
726 |
full_original_emb = get_embs(
|
727 |
model,
|
728 |
example_cell,
|
|
|
730 |
layer_to_quant,
|
731 |
self.pad_token_id,
|
732 |
self.forward_batch_size,
|
|
|
733 |
summary_stat=None,
|
734 |
silent=True,
|
735 |
)
|
736 |
|
|
|
|
|
|
|
|
|
|
|
737 |
# gene_list is used to assign cos sims back to genes
|
|
|
738 |
# need to remove the anchor gene
|
739 |
+
gene_list = example_cell["input_ids"][0][:]
|
740 |
if self.anchor_token is not None:
|
741 |
for token in self.anchor_token:
|
742 |
gene_list.remove(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
|
744 |
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
745 |
example_cell,
|
|
|
750 |
self.nproc,
|
751 |
)
|
752 |
|
753 |
+
full_perturbation_emb = get_embs(
|
754 |
+
model,
|
755 |
+
perturbation_batch,
|
756 |
+
"gene",
|
757 |
+
layer_to_quant,
|
758 |
+
self.pad_token_id,
|
759 |
+
self.forward_batch_size,
|
760 |
+
summary_stat=None,
|
761 |
+
silent=True,
|
762 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
763 |
|
764 |
+
num_inds_perturbed = 1 + self.combos
|
765 |
+
# need to remove overexpressed gene to quantify cosine shifts
|
766 |
+
if self.perturb_type == "overexpress":
|
767 |
+
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
768 |
+
gene_list = gene_list[
|
769 |
+
num_inds_perturbed:
|
770 |
+
] # index 0 is not overexpressed
|
771 |
|
772 |
+
elif self.perturb_type == "delete":
|
773 |
+
perturbation_emb = full_perturbation_emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
774 |
|
775 |
+
original_batch = pu.make_comparison_batch(
|
776 |
+
full_original_emb, indices_to_perturb, perturb_group=False
|
|
|
777 |
)
|
778 |
|
779 |
+
if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
|
780 |
+
gene_cos_sims = pu.quant_cos_sims(
|
781 |
+
perturbation_emb,
|
782 |
+
original_batch,
|
783 |
+
self.cell_states_to_model,
|
784 |
+
self.state_embs_dict,
|
785 |
+
emb_mode="gene",
|
786 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
787 |
if self.cell_states_to_model is not None:
|
788 |
+
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
789 |
+
full_original_emb, "mean_pool"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
790 |
)
|
791 |
+
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
792 |
+
full_perturbation_emb, "mean_pool"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
)
|
|
|
|
|
|
|
|
|
794 |
|
795 |
+
cell_cos_sims = pu.quant_cos_sims(
|
796 |
+
perturbation_cell_emb,
|
797 |
+
original_cell_emb,
|
798 |
+
self.cell_states_to_model,
|
799 |
+
self.state_embs_dict,
|
800 |
+
emb_mode="cell",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
801 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
802 |
|
803 |
+
if self.emb_mode == "cell_and_gene":
|
804 |
+
# remove perturbed index for gene list
|
805 |
+
perturbed_gene_dict = {
|
806 |
+
gene: gene_list[:i] + gene_list[i + 1 :]
|
807 |
+
for i, gene in enumerate(gene_list)
|
808 |
+
}
|
|
|
|
|
809 |
|
810 |
+
for perturbation_i, perturbed_gene in enumerate(gene_list):
|
811 |
+
for gene_j, affected_gene in enumerate(
|
812 |
+
perturbed_gene_dict[perturbed_gene]
|
813 |
+
):
|
814 |
+
try:
|
815 |
+
stored_gene_embs_dict[
|
816 |
+
(perturbed_gene, affected_gene)
|
817 |
+
].append(gene_cos_sims[perturbation_i, gene_j].item())
|
818 |
+
except KeyError:
|
819 |
+
stored_gene_embs_dict[
|
820 |
+
(perturbed_gene, affected_gene)
|
821 |
+
] = gene_cos_sims[perturbation_i, gene_j].item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
822 |
|
823 |
+
if self.cell_states_to_model is None:
|
824 |
+
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
825 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
826 |
+
cos_sims_dict,
|
827 |
+
cos_sims_data,
|
828 |
+
filtered_input_data,
|
829 |
+
indices_to_perturb,
|
830 |
+
gene_list,
|
831 |
+
)
|
832 |
+
else:
|
833 |
+
cos_sims_data = cell_cos_sims
|
834 |
+
for state in cos_sims_dict.keys():
|
835 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
836 |
+
cos_sims_dict[state],
|
837 |
+
cos_sims_data[state],
|
838 |
+
filtered_input_data,
|
839 |
+
indices_to_perturb,
|
840 |
+
gene_list,
|
841 |
)
|
842 |
|
843 |
+
# save dict to disk every 100 cells
|
844 |
+
if i % 100 == 0:
|
845 |
+
pu.write_perturbation_dictionary(
|
846 |
+
cos_sims_dict,
|
847 |
+
f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
|
848 |
+
)
|
849 |
+
if self.emb_mode == "cell_and_gene":
|
850 |
+
pu.write_perturbation_dictionary(
|
851 |
+
stored_gene_embs_dict,
|
852 |
+
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
853 |
)
|
854 |
|
855 |
+
# reset and clear memory every 1000 cells
|
856 |
+
if i % 1000 == 0:
|
857 |
+
pickle_batch += 1
|
858 |
+
if self.cell_states_to_model is None:
|
859 |
+
cos_sims_dict = defaultdict(list)
|
860 |
+
else:
|
861 |
+
cos_sims_dict = {
|
862 |
+
state: defaultdict(list)
|
863 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
864 |
+
}
|
|
|
|
|
865 |
|
866 |
+
if self.emb_mode == "cell_and_gene":
|
867 |
+
stored_gene_embs_dict = defaultdict(list)
|
|
|
|
|
868 |
|
869 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
870 |
|
871 |
+
pu.write_perturbation_dictionary(
|
872 |
+
cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
|
873 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
874 |
|
875 |
+
if self.emb_mode == "cell_and_gene":
|
876 |
pu.write_perturbation_dictionary(
|
877 |
+
stored_gene_embs_dict,
|
878 |
+
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
879 |
)
|
880 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
881 |
def update_perturbation_dictionary(
|
882 |
self,
|
883 |
cos_sims_dict: defaultdict,
|
884 |
cos_sims_data: torch.Tensor,
|
885 |
+
filtered_input_data: Dataset,
|
886 |
+
indices_to_perturb: List[List[int]],
|
887 |
gene_list=None,
|
888 |
):
|
889 |
if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
|
890 |
logger.error(
|
891 |
f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
|
892 |
+
cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
|
893 |
+
len(gene_list) = {len(gene_list)}."
|
894 |
)
|
895 |
raise
|
896 |
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -37,8 +37,10 @@ from scipy.stats import ranksums
|
|
37 |
from sklearn.mixture import GaussianMixture
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
40 |
-
from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
|
41 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
|
|
|
|
|
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
|
@@ -114,7 +116,6 @@ def read_dictionaries(
|
|
114 |
state_dict[state_value][key] += new_dict[key]
|
115 |
except KeyError:
|
116 |
state_dict[state_value][key] = new_dict[key]
|
117 |
-
|
118 |
if not file_found:
|
119 |
logger.error(
|
120 |
"No raw data for processing found within provided directory. "
|
@@ -191,69 +192,34 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
191 |
|
192 |
|
193 |
# aggregate data for single perturbation in multiple cells
|
194 |
-
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list
|
195 |
-
names = ["
|
196 |
-
|
197 |
-
if isinstance(genes_perturbed, list):
|
198 |
-
if len(genes_perturbed) > 1:
|
199 |
-
gene_ids_df = cos_sims_df.loc[
|
200 |
-
np.isin(
|
201 |
-
[set(idx) for idx in cos_sims_df["Ensembl_ID"]],
|
202 |
-
set(genes_perturbed),
|
203 |
-
),
|
204 |
-
:,
|
205 |
-
]
|
206 |
-
else:
|
207 |
-
gene_ids_df = cos_sims_df.loc[
|
208 |
-
np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
|
209 |
-
]
|
210 |
-
else:
|
211 |
-
logger.error(
|
212 |
-
"aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
|
213 |
-
)
|
214 |
-
raise
|
215 |
-
|
216 |
-
if gene_ids_df.empty:
|
217 |
-
logger.error("genes_to_perturb not found in data.")
|
218 |
-
raise
|
219 |
-
|
220 |
-
tokens = gene_ids_df["Gene"]
|
221 |
-
symbols = gene_ids_df["Gene_name"]
|
222 |
-
|
223 |
-
for token, symbol in zip(tokens, symbols):
|
224 |
-
cos_shift_data = []
|
225 |
-
for dict_i in dict_list:
|
226 |
-
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
227 |
-
|
228 |
-
df = pd.DataFrame(columns=names)
|
229 |
-
df["Cosine_sim"] = cos_shift_data
|
230 |
-
df["Gene"] = symbol
|
231 |
-
cos_sims_full_dfs.append(df)
|
232 |
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
|
236 |
def find(variable, x):
|
237 |
try:
|
238 |
if x in variable: # Test if variable is iterable and contains x
|
239 |
return True
|
240 |
-
elif x == variable:
|
241 |
-
return True
|
242 |
except (ValueError, TypeError):
|
243 |
return x == variable # Test if variable is x if non-iterable
|
244 |
|
245 |
|
246 |
def isp_aggregate_gene_shifts(
|
247 |
-
cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
|
248 |
):
|
249 |
cos_shift_data = dict()
|
250 |
for i in trange(cos_sims_df.shape[0]):
|
251 |
token = cos_sims_df["Gene"][i]
|
252 |
for dict_i in dict_list:
|
253 |
-
if
|
254 |
-
affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
|
255 |
-
else:
|
256 |
-
affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
|
257 |
for key in affected_pairs:
|
258 |
if key in cos_shift_data.keys():
|
259 |
cos_shift_data[key] += dict_i.get(key, [])
|
@@ -266,11 +232,11 @@ def isp_aggregate_gene_shifts(
|
|
266 |
cos_sims_full_df = pd.DataFrame()
|
267 |
cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
|
268 |
cos_sims_full_df["Gene_name"] = [
|
269 |
-
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"]
|
270 |
for k, v in cos_data_mean.items()
|
271 |
]
|
272 |
cos_sims_full_df["Ensembl_ID"] = [
|
273 |
-
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"]
|
274 |
for k, v in cos_data_mean.items()
|
275 |
]
|
276 |
|
@@ -282,15 +248,15 @@ def isp_aggregate_gene_shifts(
|
|
282 |
cos_sims_full_df["Affected_Ensembl_ID"] = [
|
283 |
gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
|
284 |
]
|
285 |
-
cos_sims_full_df["
|
286 |
-
cos_sims_full_df["
|
287 |
cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
|
288 |
|
289 |
specific_val = "cell_emb"
|
290 |
cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
|
291 |
-
# reorder so cell embs are at the top and all are subordered by magnitude of cosine
|
292 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
293 |
-
by=(["temp", "
|
294 |
).drop("temp", axis=1)
|
295 |
|
296 |
return cos_sims_full_df
|
@@ -681,7 +647,7 @@ class InSilicoPerturberStats:
|
|
681 |
cell_states_to_model=None,
|
682 |
pickle_suffix="_raw.pickle",
|
683 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
684 |
-
gene_name_id_dictionary_file=
|
685 |
):
|
686 |
"""
|
687 |
Initialize in silico perturber stats generator.
|
@@ -700,7 +666,7 @@ class InSilicoPerturberStats:
|
|
700 |
| Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
701 |
| Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
|
702 |
combos : {0,1,2}
|
703 |
-
| Whether
|
704 |
anchor_gene : None, str
|
705 |
| ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
|
706 |
| For example, if combos=1 and anchor_gene="ENSG00000136574":
|
@@ -948,11 +914,11 @@ class InSilicoPerturberStats:
|
|
948 |
| 1: within impact component; 0: not within impact component
|
949 |
| "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
950 |
|
951 |
-
| In case of aggregating
|
952 |
| "Perturbed": ID(s) of gene(s) being perturbed
|
953 |
| "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
954 |
-
| "
|
955 |
-
| "
|
956 |
"""
|
957 |
|
958 |
if self.mode not in [
|
@@ -1052,30 +1018,14 @@ class InSilicoPerturberStats:
|
|
1052 |
)
|
1053 |
|
1054 |
elif self.mode == "aggregate_data":
|
1055 |
-
cos_sims_df = isp_aggregate_grouped_perturb(
|
1056 |
-
cos_sims_df_initial, dict_list, self.genes_perturbed
|
1057 |
-
)
|
1058 |
|
1059 |
elif self.mode == "aggregate_gene_shifts":
|
1060 |
-
if (self.genes_perturbed == "all") and (self.combos == 0):
|
1061 |
-
tuple_types = [
|
1062 |
-
True if isinstance(genes, tuple) else False for genes in gene_list
|
1063 |
-
]
|
1064 |
-
if all(tuple_types):
|
1065 |
-
token_dtype = "tuple"
|
1066 |
-
elif not any(tuple_types):
|
1067 |
-
token_dtype = "nontuple"
|
1068 |
-
else:
|
1069 |
-
token_dtype = "mix"
|
1070 |
-
else:
|
1071 |
-
token_dtype = "mix"
|
1072 |
-
|
1073 |
cos_sims_df = isp_aggregate_gene_shifts(
|
1074 |
cos_sims_df_initial,
|
1075 |
dict_list,
|
1076 |
self.gene_token_id_dict,
|
1077 |
self.gene_id_name_dict,
|
1078 |
-
token_dtype,
|
1079 |
)
|
1080 |
|
1081 |
# save perturbation stats to output_path
|
|
|
37 |
from sklearn.mixture import GaussianMixture
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
|
|
40 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
41 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
42 |
+
|
43 |
+
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
44 |
|
45 |
logger = logging.getLogger(__name__)
|
46 |
|
|
|
116 |
state_dict[state_value][key] += new_dict[key]
|
117 |
except KeyError:
|
118 |
state_dict[state_value][key] = new_dict[key]
|
|
|
119 |
if not file_found:
|
120 |
logger.error(
|
121 |
"No raw data for processing found within provided directory. "
|
|
|
192 |
|
193 |
|
194 |
# aggregate data for single perturbation in multiple cells
|
195 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
|
196 |
+
names = ["Cosine_shift"]
|
197 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
+
cos_shift_data = []
|
200 |
+
token = cos_sims_df["Gene"][0]
|
201 |
+
for dict_i in dict_list:
|
202 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
203 |
+
cos_sims_full_df["Cosine_shift"] = cos_shift_data
|
204 |
+
return cos_sims_full_df
|
205 |
|
206 |
|
207 |
def find(variable, x):
|
208 |
try:
|
209 |
if x in variable: # Test if variable is iterable and contains x
|
210 |
return True
|
|
|
|
|
211 |
except (ValueError, TypeError):
|
212 |
return x == variable # Test if variable is x if non-iterable
|
213 |
|
214 |
|
215 |
def isp_aggregate_gene_shifts(
|
216 |
+
cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
|
217 |
):
|
218 |
cos_shift_data = dict()
|
219 |
for i in trange(cos_sims_df.shape[0]):
|
220 |
token = cos_sims_df["Gene"][i]
|
221 |
for dict_i in dict_list:
|
222 |
+
affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
|
|
|
|
|
|
|
223 |
for key in affected_pairs:
|
224 |
if key in cos_shift_data.keys():
|
225 |
cos_shift_data[key] += dict_i.get(key, [])
|
|
|
232 |
cos_sims_full_df = pd.DataFrame()
|
233 |
cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
|
234 |
cos_sims_full_df["Gene_name"] = [
|
235 |
+
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
|
236 |
for k, v in cos_data_mean.items()
|
237 |
]
|
238 |
cos_sims_full_df["Ensembl_ID"] = [
|
239 |
+
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
|
240 |
for k, v in cos_data_mean.items()
|
241 |
]
|
242 |
|
|
|
248 |
cos_sims_full_df["Affected_Ensembl_ID"] = [
|
249 |
gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
|
250 |
]
|
251 |
+
cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
|
252 |
+
cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
|
253 |
cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
|
254 |
|
255 |
specific_val = "cell_emb"
|
256 |
cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
|
257 |
+
# reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
|
258 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
259 |
+
by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
|
260 |
).drop("temp", axis=1)
|
261 |
|
262 |
return cos_sims_full_df
|
|
|
647 |
cell_states_to_model=None,
|
648 |
pickle_suffix="_raw.pickle",
|
649 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
650 |
+
gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
|
651 |
):
|
652 |
"""
|
653 |
Initialize in silico perturber stats generator.
|
|
|
666 |
| Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
667 |
| Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
|
668 |
combos : {0,1,2}
|
669 |
+
| Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
|
670 |
anchor_gene : None, str
|
671 |
| ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
|
672 |
| For example, if combos=1 and anchor_gene="ENSG00000136574":
|
|
|
914 |
| 1: within impact component; 0: not within impact component
|
915 |
| "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
916 |
|
917 |
+
| In case of aggregating gene shifts:
|
918 |
| "Perturbed": ID(s) of gene(s) being perturbed
|
919 |
| "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
920 |
+
| "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
|
921 |
+
| "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
|
922 |
"""
|
923 |
|
924 |
if self.mode not in [
|
|
|
1018 |
)
|
1019 |
|
1020 |
elif self.mode == "aggregate_data":
|
1021 |
+
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
|
|
|
|
1022 |
|
1023 |
elif self.mode == "aggregate_gene_shifts":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1024 |
cos_sims_df = isp_aggregate_gene_shifts(
|
1025 |
cos_sims_df_initial,
|
1026 |
dict_list,
|
1027 |
self.gene_token_id_dict,
|
1028 |
self.gene_id_name_dict,
|
|
|
1029 |
)
|
1030 |
|
1031 |
# save perturbation stats to output_path
|
geneformer/mtl/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
# ruff: noqa: F401
|
|
|
|
geneformer/mtl/collators.py
DELETED
@@ -1,76 +0,0 @@
|
|
1 |
-
# imports
|
2 |
-
import torch
|
3 |
-
import pickle
|
4 |
-
from ..collator_for_classification import DataCollatorForGeneClassification
|
5 |
-
from .. import TOKEN_DICTIONARY_FILE
|
6 |
-
|
7 |
-
"""Geneformer collator for multi-task cell classification."""
|
8 |
-
|
9 |
-
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
10 |
-
class_type = "cell"
|
11 |
-
|
12 |
-
@staticmethod
|
13 |
-
def load_token_dictionary():
|
14 |
-
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
15 |
-
return pickle.load(f)
|
16 |
-
|
17 |
-
def __init__(self, *args, **kwargs) -> None:
|
18 |
-
# Load the token dictionary
|
19 |
-
token_dictionary = self.load_token_dictionary()
|
20 |
-
# Use the loaded token dictionary
|
21 |
-
super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
|
22 |
-
|
23 |
-
def _prepare_batch(self, features):
|
24 |
-
# Process inputs as usual
|
25 |
-
batch = self.tokenizer.pad(
|
26 |
-
features,
|
27 |
-
class_type=self.class_type,
|
28 |
-
padding=self.padding,
|
29 |
-
max_length=self.max_length,
|
30 |
-
pad_to_multiple_of=self.pad_to_multiple_of,
|
31 |
-
return_tensors="pt",
|
32 |
-
)
|
33 |
-
|
34 |
-
# Check if labels are present
|
35 |
-
if "label" in features[0]:
|
36 |
-
# Initialize labels dictionary for all tasks
|
37 |
-
labels = {task: [] for task in features[0]["label"].keys()}
|
38 |
-
# Populate labels for each task
|
39 |
-
for feature in features:
|
40 |
-
for task, label in feature["label"].items():
|
41 |
-
labels[task].append(label)
|
42 |
-
|
43 |
-
# Convert label lists to tensors, handling dictionaries appropriately
|
44 |
-
for task in labels:
|
45 |
-
if isinstance(labels[task][0], (list, torch.Tensor)):
|
46 |
-
dtype = torch.long
|
47 |
-
labels[task] = torch.tensor(labels[task], dtype=dtype)
|
48 |
-
elif isinstance(labels[task][0], dict):
|
49 |
-
# Handle dict specifically if needed
|
50 |
-
pass # Resolve nested data structure
|
51 |
-
|
52 |
-
# Update the batch to include task-specific labels
|
53 |
-
batch["labels"] = labels
|
54 |
-
else:
|
55 |
-
# If no labels are present, create empty labels for all tasks
|
56 |
-
batch["labels"] = {
|
57 |
-
task: torch.tensor([], dtype=torch.long)
|
58 |
-
for task in features[0]["input_ids"].keys()
|
59 |
-
}
|
60 |
-
|
61 |
-
return batch
|
62 |
-
|
63 |
-
def __call__(self, features):
|
64 |
-
batch = self._prepare_batch(features)
|
65 |
-
for k, v in batch.items():
|
66 |
-
if torch.is_tensor(v):
|
67 |
-
batch[k] = v.clone().detach()
|
68 |
-
elif isinstance(v, dict):
|
69 |
-
# Assuming nested structure needs conversion
|
70 |
-
batch[k] = {
|
71 |
-
task: torch.tensor(labels, dtype=torch.int64)
|
72 |
-
for task, labels in v.items()
|
73 |
-
}
|
74 |
-
else:
|
75 |
-
batch[k] = torch.tensor(v, dtype=torch.int64)
|
76 |
-
return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/data.py
DELETED
@@ -1,150 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from .collators import DataCollatorForMultitaskCellClassification
|
4 |
-
from .imports import *
|
5 |
-
|
6 |
-
|
7 |
-
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
|
8 |
-
try:
|
9 |
-
dataset = load_from_disk(dataset_path)
|
10 |
-
|
11 |
-
task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
|
12 |
-
task_to_column = dict(zip(task_names, config["task_columns"]))
|
13 |
-
config["task_names"] = task_names
|
14 |
-
|
15 |
-
if not is_test:
|
16 |
-
available_columns = set(dataset.column_names)
|
17 |
-
for column in task_to_column.values():
|
18 |
-
if column not in available_columns:
|
19 |
-
raise KeyError(
|
20 |
-
f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
|
21 |
-
)
|
22 |
-
|
23 |
-
label_mappings = {}
|
24 |
-
task_label_mappings = {}
|
25 |
-
cell_id_mapping = {}
|
26 |
-
num_labels_list = []
|
27 |
-
|
28 |
-
# Load or create task label mappings
|
29 |
-
if not is_test:
|
30 |
-
for task, column in task_to_column.items():
|
31 |
-
unique_values = sorted(set(dataset[column])) # Ensure consistency
|
32 |
-
label_mappings[column] = {
|
33 |
-
label: idx for idx, label in enumerate(unique_values)
|
34 |
-
}
|
35 |
-
task_label_mappings[task] = label_mappings[column]
|
36 |
-
num_labels_list.append(len(unique_values))
|
37 |
-
|
38 |
-
# Print the mappings for each task with dataset type prefix
|
39 |
-
for task, mapping in task_label_mappings.items():
|
40 |
-
print(
|
41 |
-
f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
|
42 |
-
) # sanity check, for train/validation splits
|
43 |
-
|
44 |
-
# Save the task label mappings as a pickle file
|
45 |
-
with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
|
46 |
-
pickle.dump(task_label_mappings, f)
|
47 |
-
else:
|
48 |
-
# Load task label mappings from pickle file for test data
|
49 |
-
with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
50 |
-
task_label_mappings = pickle.load(f)
|
51 |
-
|
52 |
-
# Infer num_labels_list from task_label_mappings
|
53 |
-
for task, mapping in task_label_mappings.items():
|
54 |
-
num_labels_list.append(len(mapping))
|
55 |
-
|
56 |
-
# Store unique cell IDs in a separate dictionary
|
57 |
-
for idx, record in enumerate(dataset):
|
58 |
-
cell_id = record.get("unique_cell_id", idx)
|
59 |
-
cell_id_mapping[idx] = cell_id
|
60 |
-
|
61 |
-
# Transform records to the desired format
|
62 |
-
transformed_dataset = []
|
63 |
-
for idx, record in enumerate(dataset):
|
64 |
-
transformed_record = {}
|
65 |
-
transformed_record["input_ids"] = torch.tensor(
|
66 |
-
record["input_ids"], dtype=torch.long
|
67 |
-
)
|
68 |
-
|
69 |
-
# Use index-based cell ID for internal tracking
|
70 |
-
transformed_record["cell_id"] = idx
|
71 |
-
|
72 |
-
if not is_test:
|
73 |
-
# Prepare labels
|
74 |
-
label_dict = {}
|
75 |
-
for task, column in task_to_column.items():
|
76 |
-
label_value = record[column]
|
77 |
-
label_index = task_label_mappings[task][label_value]
|
78 |
-
label_dict[task] = label_index
|
79 |
-
transformed_record["label"] = label_dict
|
80 |
-
else:
|
81 |
-
# Create dummy labels for test data
|
82 |
-
label_dict = {task: -1 for task in config["task_names"]}
|
83 |
-
transformed_record["label"] = label_dict
|
84 |
-
|
85 |
-
transformed_dataset.append(transformed_record)
|
86 |
-
|
87 |
-
return transformed_dataset, cell_id_mapping, num_labels_list
|
88 |
-
except KeyError as e:
|
89 |
-
print(f"Missing configuration or dataset key: {e}")
|
90 |
-
except Exception as e:
|
91 |
-
print(f"An error occurred while loading or preprocessing data: {e}")
|
92 |
-
return None, None, None
|
93 |
-
|
94 |
-
|
95 |
-
def preload_and_process_data(config):
|
96 |
-
# Load and preprocess data once
|
97 |
-
train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
|
98 |
-
config["train_path"], config, dataset_type="train"
|
99 |
-
)
|
100 |
-
val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
|
101 |
-
config["val_path"], config, dataset_type="validation"
|
102 |
-
)
|
103 |
-
return (
|
104 |
-
train_dataset,
|
105 |
-
train_cell_id_mapping,
|
106 |
-
val_dataset,
|
107 |
-
val_cell_id_mapping,
|
108 |
-
num_labels_list,
|
109 |
-
)
|
110 |
-
|
111 |
-
|
112 |
-
def get_data_loader(preprocessed_dataset, batch_size):
|
113 |
-
nproc = os.cpu_count() ### I/O operations
|
114 |
-
|
115 |
-
data_collator = DataCollatorForMultitaskCellClassification()
|
116 |
-
|
117 |
-
loader = DataLoader(
|
118 |
-
preprocessed_dataset,
|
119 |
-
batch_size=batch_size,
|
120 |
-
shuffle=True,
|
121 |
-
collate_fn=data_collator,
|
122 |
-
num_workers=nproc,
|
123 |
-
pin_memory=True,
|
124 |
-
)
|
125 |
-
return loader
|
126 |
-
|
127 |
-
|
128 |
-
def preload_data(config):
|
129 |
-
# Preprocessing the data before the Optuna trials start
|
130 |
-
train_loader = get_data_loader("train", config)
|
131 |
-
val_loader = get_data_loader("val", config)
|
132 |
-
return train_loader, val_loader
|
133 |
-
|
134 |
-
|
135 |
-
def load_and_preprocess_test_data(config):
|
136 |
-
"""
|
137 |
-
Load and preprocess test data, treating it as unlabeled.
|
138 |
-
"""
|
139 |
-
return load_and_preprocess_data(config["test_path"], config, is_test=True)
|
140 |
-
|
141 |
-
|
142 |
-
def prepare_test_loader(config):
|
143 |
-
"""
|
144 |
-
Prepare DataLoader for the test dataset.
|
145 |
-
"""
|
146 |
-
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
|
147 |
-
config
|
148 |
-
)
|
149 |
-
test_loader = get_data_loader(test_dataset, config["batch_size"])
|
150 |
-
return test_loader, cell_id_mapping, num_labels_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/eval_utils.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
|
3 |
-
from .imports import * # noqa # isort:skip
|
4 |
-
from .data import prepare_test_loader # noqa # isort:skip
|
5 |
-
from .model import GeneformerMultiTask
|
6 |
-
|
7 |
-
|
8 |
-
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
9 |
-
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
10 |
-
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
11 |
-
cell_ids = []
|
12 |
-
|
13 |
-
# # Load task label mappings from pickle file
|
14 |
-
# with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
15 |
-
# task_label_mappings = pickle.load(f)
|
16 |
-
|
17 |
-
model.eval()
|
18 |
-
with torch.no_grad():
|
19 |
-
for batch in test_loader:
|
20 |
-
input_ids = batch["input_ids"].to(device)
|
21 |
-
attention_mask = batch["attention_mask"].to(device)
|
22 |
-
_, logits, _ = model(input_ids, attention_mask)
|
23 |
-
for sample_idx in range(len(batch["input_ids"])):
|
24 |
-
cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
|
25 |
-
cell_ids.append(cell_id)
|
26 |
-
for i, task_name in enumerate(config["task_names"]):
|
27 |
-
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
28 |
-
pred_prob = (
|
29 |
-
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
30 |
-
)
|
31 |
-
task_pred_labels[task_name].append(pred_label)
|
32 |
-
task_pred_probs[task_name].append(pred_prob)
|
33 |
-
|
34 |
-
# Save test predictions with cell IDs and probabilities to CSV
|
35 |
-
test_results_dir = config["results_dir"]
|
36 |
-
os.makedirs(test_results_dir, exist_ok=True)
|
37 |
-
test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
|
38 |
-
|
39 |
-
rows = []
|
40 |
-
for sample_idx in range(len(cell_ids)):
|
41 |
-
row = {"Cell ID": cell_ids[sample_idx]}
|
42 |
-
for task_name in config["task_names"]:
|
43 |
-
row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
|
44 |
-
row[f"{task_name} Probabilities"] = ",".join(
|
45 |
-
map(str, task_pred_probs[task_name][sample_idx])
|
46 |
-
)
|
47 |
-
rows.append(row)
|
48 |
-
|
49 |
-
df = pd.DataFrame(rows)
|
50 |
-
df.to_csv(test_preds_file, index=False)
|
51 |
-
print(f"Test predictions saved to {test_preds_file}")
|
52 |
-
|
53 |
-
|
54 |
-
def load_and_evaluate_test_model(config):
|
55 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
-
test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
|
57 |
-
model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
|
58 |
-
hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
|
59 |
-
|
60 |
-
# Load the saved best hyperparameters
|
61 |
-
with open(hyperparams_path, "r") as f:
|
62 |
-
best_hyperparams = json.load(f)
|
63 |
-
|
64 |
-
# Extract the task weights if present, otherwise set to None
|
65 |
-
task_weights = best_hyperparams.get("task_weights", None)
|
66 |
-
normalized_task_weights = task_weights if task_weights else []
|
67 |
-
|
68 |
-
# Print the loaded hyperparameters
|
69 |
-
print("Loaded hyperparameters:")
|
70 |
-
for param, value in best_hyperparams.items():
|
71 |
-
if param == "task_weights":
|
72 |
-
print(f"normalized_task_weights: {value}")
|
73 |
-
else:
|
74 |
-
print(f"{param}: {value}")
|
75 |
-
|
76 |
-
best_model_path = os.path.join(model_directory, "pytorch_model.bin")
|
77 |
-
best_model = GeneformerMultiTask(
|
78 |
-
config["pretrained_path"],
|
79 |
-
num_labels_list,
|
80 |
-
dropout_rate=best_hyperparams["dropout_rate"],
|
81 |
-
use_task_weights=config["use_task_weights"],
|
82 |
-
task_weights=normalized_task_weights,
|
83 |
-
)
|
84 |
-
best_model.load_state_dict(torch.load(best_model_path))
|
85 |
-
best_model.to(device)
|
86 |
-
|
87 |
-
evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
|
88 |
-
print("Evaluation completed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/imports.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
import gc
|
3 |
-
import json
|
4 |
-
import os
|
5 |
-
import pickle
|
6 |
-
import sys
|
7 |
-
import warnings
|
8 |
-
from enum import Enum
|
9 |
-
from itertools import chain
|
10 |
-
from typing import Dict, List, Optional, Union
|
11 |
-
|
12 |
-
import numpy as np
|
13 |
-
import optuna
|
14 |
-
import pandas as pd
|
15 |
-
import torch
|
16 |
-
import torch.nn as nn
|
17 |
-
import torch.nn.functional as F
|
18 |
-
import torch.optim as optim
|
19 |
-
from datasets import load_from_disk
|
20 |
-
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
|
21 |
-
from sklearn.model_selection import train_test_split
|
22 |
-
from sklearn.preprocessing import LabelEncoder
|
23 |
-
from torch.utils.data import DataLoader
|
24 |
-
from transformers import (
|
25 |
-
AdamW,
|
26 |
-
BatchEncoding,
|
27 |
-
BertConfig,
|
28 |
-
BertModel,
|
29 |
-
DataCollatorForTokenClassification,
|
30 |
-
SpecialTokensMixin,
|
31 |
-
get_cosine_schedule_with_warmup,
|
32 |
-
get_linear_schedule_with_warmup,
|
33 |
-
get_scheduler,
|
34 |
-
)
|
35 |
-
from transformers.utils import logging, to_py_obj
|
36 |
-
|
37 |
-
from .collators import DataCollatorForMultitaskCellClassification
|
38 |
-
|
39 |
-
# local modules
|
40 |
-
from .data import get_data_loader, preload_and_process_data
|
41 |
-
from .model import GeneformerMultiTask
|
42 |
-
from .optuna_utils import create_optuna_study
|
43 |
-
from .utils import save_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/model.py
DELETED
@@ -1,121 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
from transformers import BertConfig, BertModel
|
4 |
-
|
5 |
-
|
6 |
-
class AttentionPool(nn.Module):
|
7 |
-
"""Attention-based pooling layer."""
|
8 |
-
|
9 |
-
def __init__(self, hidden_size):
|
10 |
-
super(AttentionPool, self).__init__()
|
11 |
-
self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
|
12 |
-
nn.init.xavier_uniform_(
|
13 |
-
self.attention_weights
|
14 |
-
) # https://pytorch.org/docs/stable/nn.init.html
|
15 |
-
|
16 |
-
def forward(self, hidden_states):
|
17 |
-
attention_scores = torch.matmul(hidden_states, self.attention_weights)
|
18 |
-
attention_scores = torch.softmax(attention_scores, dim=1)
|
19 |
-
pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
|
20 |
-
return pooled_output
|
21 |
-
|
22 |
-
|
23 |
-
class GeneformerMultiTask(nn.Module):
|
24 |
-
def __init__(
|
25 |
-
self,
|
26 |
-
pretrained_path,
|
27 |
-
num_labels_list,
|
28 |
-
dropout_rate=0.1,
|
29 |
-
use_task_weights=False,
|
30 |
-
task_weights=None,
|
31 |
-
max_layers_to_freeze=0,
|
32 |
-
use_attention_pooling=False,
|
33 |
-
):
|
34 |
-
super(GeneformerMultiTask, self).__init__()
|
35 |
-
self.config = BertConfig.from_pretrained(pretrained_path)
|
36 |
-
self.bert = BertModel(self.config)
|
37 |
-
self.num_labels_list = num_labels_list
|
38 |
-
self.use_task_weights = use_task_weights
|
39 |
-
self.dropout = nn.Dropout(dropout_rate)
|
40 |
-
self.use_attention_pooling = use_attention_pooling
|
41 |
-
|
42 |
-
if use_task_weights and (
|
43 |
-
task_weights is None or len(task_weights) != len(num_labels_list)
|
44 |
-
):
|
45 |
-
raise ValueError(
|
46 |
-
"Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
|
47 |
-
)
|
48 |
-
self.task_weights = (
|
49 |
-
task_weights if use_task_weights else [1.0] * len(num_labels_list)
|
50 |
-
)
|
51 |
-
|
52 |
-
# Freeze the specified initial layers
|
53 |
-
for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
|
54 |
-
for param in layer.parameters():
|
55 |
-
param.requires_grad = False
|
56 |
-
|
57 |
-
self.attention_pool = (
|
58 |
-
AttentionPool(self.config.hidden_size) if use_attention_pooling else None
|
59 |
-
)
|
60 |
-
|
61 |
-
self.classification_heads = nn.ModuleList(
|
62 |
-
[
|
63 |
-
nn.Linear(self.config.hidden_size, num_labels)
|
64 |
-
for num_labels in num_labels_list
|
65 |
-
]
|
66 |
-
)
|
67 |
-
# initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
|
68 |
-
for head in self.classification_heads:
|
69 |
-
nn.init.xavier_uniform_(head.weight)
|
70 |
-
nn.init.zeros_(head.bias)
|
71 |
-
|
72 |
-
def forward(self, input_ids, attention_mask, labels=None):
|
73 |
-
try:
|
74 |
-
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
75 |
-
except Exception as e:
|
76 |
-
raise RuntimeError(f"Error during BERT forward pass: {e}")
|
77 |
-
|
78 |
-
sequence_output = outputs.last_hidden_state
|
79 |
-
|
80 |
-
try:
|
81 |
-
pooled_output = (
|
82 |
-
self.attention_pool(sequence_output)
|
83 |
-
if self.use_attention_pooling
|
84 |
-
else sequence_output[:, 0, :]
|
85 |
-
)
|
86 |
-
pooled_output = self.dropout(pooled_output)
|
87 |
-
except Exception as e:
|
88 |
-
raise RuntimeError(f"Error during pooling and dropout: {e}")
|
89 |
-
|
90 |
-
total_loss = 0
|
91 |
-
logits = []
|
92 |
-
losses = []
|
93 |
-
|
94 |
-
for task_id, (head, num_labels) in enumerate(
|
95 |
-
zip(self.classification_heads, self.num_labels_list)
|
96 |
-
):
|
97 |
-
try:
|
98 |
-
task_logits = head(pooled_output)
|
99 |
-
except Exception as e:
|
100 |
-
raise RuntimeError(
|
101 |
-
f"Error during forward pass of classification head {task_id}: {e}"
|
102 |
-
)
|
103 |
-
|
104 |
-
logits.append(task_logits)
|
105 |
-
|
106 |
-
if labels is not None:
|
107 |
-
try:
|
108 |
-
loss_fct = nn.CrossEntropyLoss()
|
109 |
-
task_loss = loss_fct(
|
110 |
-
task_logits.view(-1, num_labels), labels[task_id].view(-1)
|
111 |
-
)
|
112 |
-
if self.use_task_weights:
|
113 |
-
task_loss *= self.task_weights[task_id]
|
114 |
-
total_loss += task_loss
|
115 |
-
losses.append(task_loss.item())
|
116 |
-
except Exception as e:
|
117 |
-
raise RuntimeError(
|
118 |
-
f"Error during loss computation for task {task_id}: {e}"
|
119 |
-
)
|
120 |
-
|
121 |
-
return total_loss, logits, losses if labels is not None else logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/optuna_utils.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
import optuna
|
2 |
-
from optuna.integration import TensorBoardCallback
|
3 |
-
|
4 |
-
|
5 |
-
def save_trial_callback(study, trial, trials_result_path):
|
6 |
-
with open(trials_result_path, "a") as f:
|
7 |
-
f.write(
|
8 |
-
f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
|
9 |
-
)
|
10 |
-
|
11 |
-
|
12 |
-
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
|
13 |
-
study = optuna.create_study(direction="maximize")
|
14 |
-
|
15 |
-
# init TensorBoard callback
|
16 |
-
tensorboard_callback = TensorBoardCallback(
|
17 |
-
dirname=tensorboard_log_dir, metric_name="F1 Macro"
|
18 |
-
)
|
19 |
-
|
20 |
-
# callback and TensorBoard callback
|
21 |
-
callbacks = [
|
22 |
-
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
23 |
-
tensorboard_callback,
|
24 |
-
]
|
25 |
-
|
26 |
-
study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
|
27 |
-
return study
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/train.py
DELETED
@@ -1,380 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import random
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import pandas as pd
|
6 |
-
import torch
|
7 |
-
from torch.utils.tensorboard import SummaryWriter
|
8 |
-
from tqdm import tqdm
|
9 |
-
|
10 |
-
from .imports import *
|
11 |
-
from .model import GeneformerMultiTask
|
12 |
-
from .utils import calculate_task_specific_metrics, get_layer_freeze_range
|
13 |
-
|
14 |
-
|
15 |
-
def set_seed(seed):
|
16 |
-
random.seed(seed)
|
17 |
-
np.random.seed(seed)
|
18 |
-
torch.manual_seed(seed)
|
19 |
-
torch.cuda.manual_seed_all(seed)
|
20 |
-
torch.backends.cudnn.deterministic = True
|
21 |
-
torch.backends.cudnn.benchmark = False
|
22 |
-
|
23 |
-
|
24 |
-
def initialize_wandb(config):
|
25 |
-
if config.get("use_wandb", False):
|
26 |
-
import wandb
|
27 |
-
|
28 |
-
wandb.init(project=config["wandb_project"], config=config)
|
29 |
-
print("Weights & Biases (wandb) initialized and will be used for logging.")
|
30 |
-
else:
|
31 |
-
print(
|
32 |
-
"Weights & Biases (wandb) is not enabled. Logging will use other methods."
|
33 |
-
)
|
34 |
-
|
35 |
-
|
36 |
-
def create_model(config, num_labels_list, device):
|
37 |
-
model = GeneformerMultiTask(
|
38 |
-
config["pretrained_path"],
|
39 |
-
num_labels_list,
|
40 |
-
dropout_rate=config["dropout_rate"],
|
41 |
-
use_task_weights=config["use_task_weights"],
|
42 |
-
task_weights=config["task_weights"],
|
43 |
-
max_layers_to_freeze=config["max_layers_to_freeze"],
|
44 |
-
use_attention_pooling=config["use_attention_pooling"],
|
45 |
-
)
|
46 |
-
if config["use_data_parallel"]:
|
47 |
-
model = nn.DataParallel(model)
|
48 |
-
return model.to(device)
|
49 |
-
|
50 |
-
|
51 |
-
def setup_optimizer_and_scheduler(model, config, total_steps):
|
52 |
-
optimizer = AdamW(
|
53 |
-
model.parameters(),
|
54 |
-
lr=config["learning_rate"],
|
55 |
-
weight_decay=config["weight_decay"],
|
56 |
-
)
|
57 |
-
warmup_steps = int(config["warmup_ratio"] * total_steps)
|
58 |
-
|
59 |
-
if config["lr_scheduler_type"] == "linear":
|
60 |
-
scheduler = get_linear_schedule_with_warmup(
|
61 |
-
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
|
62 |
-
)
|
63 |
-
elif config["lr_scheduler_type"] == "cosine":
|
64 |
-
scheduler = get_cosine_schedule_with_warmup(
|
65 |
-
optimizer,
|
66 |
-
num_warmup_steps=warmup_steps,
|
67 |
-
num_training_steps=total_steps,
|
68 |
-
num_cycles=0.5,
|
69 |
-
)
|
70 |
-
|
71 |
-
return optimizer, scheduler
|
72 |
-
|
73 |
-
|
74 |
-
def train_epoch(
|
75 |
-
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
76 |
-
):
|
77 |
-
model.train()
|
78 |
-
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
|
79 |
-
for batch_idx, batch in enumerate(progress_bar):
|
80 |
-
optimizer.zero_grad()
|
81 |
-
input_ids = batch["input_ids"].to(device)
|
82 |
-
attention_mask = batch["attention_mask"].to(device)
|
83 |
-
labels = [
|
84 |
-
batch["labels"][task_name].to(device) for task_name in config["task_names"]
|
85 |
-
]
|
86 |
-
|
87 |
-
loss, _, _ = model(input_ids, attention_mask, labels)
|
88 |
-
loss.backward()
|
89 |
-
|
90 |
-
if config["gradient_clipping"]:
|
91 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
|
92 |
-
|
93 |
-
optimizer.step()
|
94 |
-
scheduler.step()
|
95 |
-
|
96 |
-
writer.add_scalar(
|
97 |
-
"Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
|
98 |
-
)
|
99 |
-
if config.get("use_wandb", False):
|
100 |
-
import wandb
|
101 |
-
|
102 |
-
wandb.log({"Training Loss": loss.item()})
|
103 |
-
|
104 |
-
# Update progress bar
|
105 |
-
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
|
106 |
-
|
107 |
-
return loss.item() # Return the last batch loss
|
108 |
-
|
109 |
-
|
110 |
-
def validate_model(model, val_loader, device, config):
|
111 |
-
model.eval()
|
112 |
-
val_loss = 0.0
|
113 |
-
task_true_labels = {task_name: [] for task_name in config["task_names"]}
|
114 |
-
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
115 |
-
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
116 |
-
|
117 |
-
with torch.no_grad():
|
118 |
-
for batch in val_loader:
|
119 |
-
input_ids = batch["input_ids"].to(device)
|
120 |
-
attention_mask = batch["attention_mask"].to(device)
|
121 |
-
labels = [
|
122 |
-
batch["labels"][task_name].to(device)
|
123 |
-
for task_name in config["task_names"]
|
124 |
-
]
|
125 |
-
loss, logits, _ = model(input_ids, attention_mask, labels)
|
126 |
-
val_loss += loss.item()
|
127 |
-
|
128 |
-
for sample_idx in range(len(batch["input_ids"])):
|
129 |
-
for i, task_name in enumerate(config["task_names"]):
|
130 |
-
true_label = batch["labels"][task_name][sample_idx].item()
|
131 |
-
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
132 |
-
pred_prob = (
|
133 |
-
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
134 |
-
)
|
135 |
-
task_true_labels[task_name].append(true_label)
|
136 |
-
task_pred_labels[task_name].append(pred_label)
|
137 |
-
task_pred_probs[task_name].append(pred_prob)
|
138 |
-
|
139 |
-
val_loss /= len(val_loader)
|
140 |
-
return val_loss, task_true_labels, task_pred_labels, task_pred_probs
|
141 |
-
|
142 |
-
|
143 |
-
def log_metrics(task_metrics, val_loss, config, writer, epochs):
|
144 |
-
for task_name, metrics in task_metrics.items():
|
145 |
-
print(
|
146 |
-
f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
|
147 |
-
)
|
148 |
-
if config.get("use_wandb", False):
|
149 |
-
import wandb
|
150 |
-
|
151 |
-
wandb.log(
|
152 |
-
{
|
153 |
-
f"{task_name} Validation F1 Macro": metrics["f1"],
|
154 |
-
f"{task_name} Validation Accuracy": metrics["accuracy"],
|
155 |
-
}
|
156 |
-
)
|
157 |
-
|
158 |
-
writer.add_scalar("Validation Loss", val_loss, epochs)
|
159 |
-
for task_name, metrics in task_metrics.items():
|
160 |
-
writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
|
161 |
-
writer.add_scalar(
|
162 |
-
f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
|
163 |
-
)
|
164 |
-
|
165 |
-
|
166 |
-
def save_validation_predictions(
|
167 |
-
val_cell_id_mapping,
|
168 |
-
task_true_labels,
|
169 |
-
task_pred_labels,
|
170 |
-
task_pred_probs,
|
171 |
-
config,
|
172 |
-
trial_number=None,
|
173 |
-
):
|
174 |
-
if trial_number is not None:
|
175 |
-
trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
|
176 |
-
os.makedirs(trial_results_dir, exist_ok=True)
|
177 |
-
val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
|
178 |
-
else:
|
179 |
-
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
180 |
-
|
181 |
-
rows = []
|
182 |
-
for sample_idx in range(len(val_cell_id_mapping)):
|
183 |
-
row = {"Cell ID": val_cell_id_mapping[sample_idx]}
|
184 |
-
for task_name in config["task_names"]:
|
185 |
-
row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
|
186 |
-
row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
|
187 |
-
row[f"{task_name} Probabilities"] = ",".join(
|
188 |
-
map(str, task_pred_probs[task_name][sample_idx])
|
189 |
-
)
|
190 |
-
rows.append(row)
|
191 |
-
|
192 |
-
df = pd.DataFrame(rows)
|
193 |
-
df.to_csv(val_preds_file, index=False)
|
194 |
-
print(f"Validation predictions saved to {val_preds_file}")
|
195 |
-
|
196 |
-
|
197 |
-
def train_model(
|
198 |
-
config,
|
199 |
-
device,
|
200 |
-
train_loader,
|
201 |
-
val_loader,
|
202 |
-
train_cell_id_mapping,
|
203 |
-
val_cell_id_mapping,
|
204 |
-
num_labels_list,
|
205 |
-
):
|
206 |
-
set_seed(config["seed"])
|
207 |
-
initialize_wandb(config)
|
208 |
-
|
209 |
-
model = create_model(config, num_labels_list, device)
|
210 |
-
total_steps = len(train_loader) * config["epochs"]
|
211 |
-
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
212 |
-
|
213 |
-
log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
|
214 |
-
writer = SummaryWriter(log_dir=log_dir)
|
215 |
-
|
216 |
-
epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
|
217 |
-
for epoch in epoch_progress:
|
218 |
-
last_loss = train_epoch(
|
219 |
-
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
220 |
-
)
|
221 |
-
epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
|
222 |
-
|
223 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
224 |
-
model, val_loader, device, config
|
225 |
-
)
|
226 |
-
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
227 |
-
|
228 |
-
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
229 |
-
writer.close()
|
230 |
-
|
231 |
-
save_validation_predictions(
|
232 |
-
val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
|
233 |
-
)
|
234 |
-
|
235 |
-
if config.get("use_wandb", False):
|
236 |
-
import wandb
|
237 |
-
|
238 |
-
wandb.finish()
|
239 |
-
|
240 |
-
print(f"\nFinal Validation Loss: {val_loss:.4f}")
|
241 |
-
return val_loss, model # Return both the validation loss and the trained model
|
242 |
-
|
243 |
-
|
244 |
-
def objective(
|
245 |
-
trial,
|
246 |
-
train_loader,
|
247 |
-
val_loader,
|
248 |
-
train_cell_id_mapping,
|
249 |
-
val_cell_id_mapping,
|
250 |
-
num_labels_list,
|
251 |
-
config,
|
252 |
-
device,
|
253 |
-
):
|
254 |
-
set_seed(config["seed"]) # Set the seed before each trial
|
255 |
-
initialize_wandb(config)
|
256 |
-
|
257 |
-
# Hyperparameters
|
258 |
-
config["learning_rate"] = trial.suggest_float(
|
259 |
-
"learning_rate",
|
260 |
-
config["hyperparameters"]["learning_rate"]["low"],
|
261 |
-
config["hyperparameters"]["learning_rate"]["high"],
|
262 |
-
log=config["hyperparameters"]["learning_rate"]["log"],
|
263 |
-
)
|
264 |
-
config["warmup_ratio"] = trial.suggest_float(
|
265 |
-
"warmup_ratio",
|
266 |
-
config["hyperparameters"]["warmup_ratio"]["low"],
|
267 |
-
config["hyperparameters"]["warmup_ratio"]["high"],
|
268 |
-
)
|
269 |
-
config["weight_decay"] = trial.suggest_float(
|
270 |
-
"weight_decay",
|
271 |
-
config["hyperparameters"]["weight_decay"]["low"],
|
272 |
-
config["hyperparameters"]["weight_decay"]["high"],
|
273 |
-
)
|
274 |
-
config["dropout_rate"] = trial.suggest_float(
|
275 |
-
"dropout_rate",
|
276 |
-
config["hyperparameters"]["dropout_rate"]["low"],
|
277 |
-
config["hyperparameters"]["dropout_rate"]["high"],
|
278 |
-
)
|
279 |
-
config["lr_scheduler_type"] = trial.suggest_categorical(
|
280 |
-
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
281 |
-
)
|
282 |
-
config["use_attention_pooling"] = trial.suggest_categorical(
|
283 |
-
"use_attention_pooling", [False]
|
284 |
-
)
|
285 |
-
|
286 |
-
if config["use_task_weights"]:
|
287 |
-
config["task_weights"] = [
|
288 |
-
trial.suggest_float(
|
289 |
-
f"task_weight_{i}",
|
290 |
-
config["hyperparameters"]["task_weights"]["low"],
|
291 |
-
config["hyperparameters"]["task_weights"]["high"],
|
292 |
-
)
|
293 |
-
for i in range(len(num_labels_list))
|
294 |
-
]
|
295 |
-
weight_sum = sum(config["task_weights"])
|
296 |
-
config["task_weights"] = [
|
297 |
-
weight / weight_sum for weight in config["task_weights"]
|
298 |
-
]
|
299 |
-
else:
|
300 |
-
config["task_weights"] = None
|
301 |
-
|
302 |
-
# Dynamic range for max_layers_to_freeze
|
303 |
-
freeze_range = get_layer_freeze_range(config["pretrained_path"])
|
304 |
-
config["max_layers_to_freeze"] = trial.suggest_int(
|
305 |
-
"max_layers_to_freeze",
|
306 |
-
freeze_range["min"],
|
307 |
-
freeze_range["max"]
|
308 |
-
)
|
309 |
-
|
310 |
-
model = create_model(config, num_labels_list, device)
|
311 |
-
total_steps = len(train_loader) * config["epochs"]
|
312 |
-
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
313 |
-
|
314 |
-
log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
|
315 |
-
writer = SummaryWriter(log_dir=log_dir)
|
316 |
-
|
317 |
-
for epoch in range(config["epochs"]):
|
318 |
-
train_epoch(
|
319 |
-
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
320 |
-
)
|
321 |
-
|
322 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
323 |
-
model, val_loader, device, config
|
324 |
-
)
|
325 |
-
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
326 |
-
|
327 |
-
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
328 |
-
writer.close()
|
329 |
-
|
330 |
-
save_validation_predictions(
|
331 |
-
val_cell_id_mapping,
|
332 |
-
task_true_labels,
|
333 |
-
task_pred_labels,
|
334 |
-
task_pred_probs,
|
335 |
-
config,
|
336 |
-
trial.number,
|
337 |
-
)
|
338 |
-
|
339 |
-
trial.set_user_attr("model_state_dict", model.state_dict())
|
340 |
-
trial.set_user_attr("task_weights", config["task_weights"])
|
341 |
-
|
342 |
-
trial.report(val_loss, config["epochs"])
|
343 |
-
|
344 |
-
if trial.should_prune():
|
345 |
-
raise optuna.TrialPruned()
|
346 |
-
|
347 |
-
if config.get("use_wandb", False):
|
348 |
-
import wandb
|
349 |
-
|
350 |
-
wandb.log(
|
351 |
-
{
|
352 |
-
"trial_number": trial.number,
|
353 |
-
"val_loss": val_loss,
|
354 |
-
**{
|
355 |
-
f"{task_name}_f1": metrics["f1"]
|
356 |
-
for task_name, metrics in task_metrics.items()
|
357 |
-
},
|
358 |
-
**{
|
359 |
-
f"{task_name}_accuracy": metrics["accuracy"]
|
360 |
-
for task_name, metrics in task_metrics.items()
|
361 |
-
},
|
362 |
-
**{
|
363 |
-
k: v
|
364 |
-
for k, v in config.items()
|
365 |
-
if k
|
366 |
-
in [
|
367 |
-
"learning_rate",
|
368 |
-
"warmup_ratio",
|
369 |
-
"weight_decay",
|
370 |
-
"dropout_rate",
|
371 |
-
"lr_scheduler_type",
|
372 |
-
"use_attention_pooling",
|
373 |
-
"max_layers_to_freeze",
|
374 |
-
]
|
375 |
-
},
|
376 |
-
}
|
377 |
-
)
|
378 |
-
wandb.finish()
|
379 |
-
|
380 |
-
return val_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/train_utils.py
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
|
3 |
-
from .data import get_data_loader, preload_and_process_data
|
4 |
-
from .imports import *
|
5 |
-
from .model import GeneformerMultiTask
|
6 |
-
from .train import objective, train_model
|
7 |
-
from .utils import save_model
|
8 |
-
|
9 |
-
|
10 |
-
def set_seed(seed):
|
11 |
-
random.seed(seed)
|
12 |
-
np.random.seed(seed)
|
13 |
-
torch.manual_seed(seed)
|
14 |
-
torch.cuda.manual_seed_all(seed)
|
15 |
-
torch.backends.cudnn.deterministic = True
|
16 |
-
torch.backends.cudnn.benchmark = False
|
17 |
-
|
18 |
-
|
19 |
-
def run_manual_tuning(config):
|
20 |
-
# Set seed for reproducibility
|
21 |
-
set_seed(config["seed"])
|
22 |
-
|
23 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
-
(
|
25 |
-
train_dataset,
|
26 |
-
train_cell_id_mapping,
|
27 |
-
val_dataset,
|
28 |
-
val_cell_id_mapping,
|
29 |
-
num_labels_list,
|
30 |
-
) = preload_and_process_data(config)
|
31 |
-
train_loader = get_data_loader(train_dataset, config["batch_size"])
|
32 |
-
val_loader = get_data_loader(val_dataset, config["batch_size"])
|
33 |
-
|
34 |
-
# Print the manual hyperparameters being used
|
35 |
-
print("\nManual hyperparameters being used:")
|
36 |
-
for key, value in config["manual_hyperparameters"].items():
|
37 |
-
print(f"{key}: {value}")
|
38 |
-
print() # Add an empty line for better readability
|
39 |
-
|
40 |
-
# Use the manual hyperparameters
|
41 |
-
for key, value in config["manual_hyperparameters"].items():
|
42 |
-
config[key] = value
|
43 |
-
|
44 |
-
# Train the model
|
45 |
-
val_loss, trained_model = train_model(
|
46 |
-
config,
|
47 |
-
device,
|
48 |
-
train_loader,
|
49 |
-
val_loader,
|
50 |
-
train_cell_id_mapping,
|
51 |
-
val_cell_id_mapping,
|
52 |
-
num_labels_list,
|
53 |
-
)
|
54 |
-
|
55 |
-
print(f"\nValidation loss with manual hyperparameters: {val_loss}")
|
56 |
-
|
57 |
-
# Save the trained model
|
58 |
-
model_save_directory = os.path.join(
|
59 |
-
config["model_save_path"], "GeneformerMultiTask"
|
60 |
-
)
|
61 |
-
save_model(trained_model, model_save_directory)
|
62 |
-
|
63 |
-
# Save the hyperparameters
|
64 |
-
hyperparams_to_save = {
|
65 |
-
**config["manual_hyperparameters"],
|
66 |
-
"dropout_rate": config["dropout_rate"],
|
67 |
-
"use_task_weights": config["use_task_weights"],
|
68 |
-
"task_weights": config["task_weights"],
|
69 |
-
"max_layers_to_freeze": config["max_layers_to_freeze"],
|
70 |
-
"use_attention_pooling": config["use_attention_pooling"],
|
71 |
-
}
|
72 |
-
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
73 |
-
with open(hyperparams_path, "w") as f:
|
74 |
-
json.dump(hyperparams_to_save, f)
|
75 |
-
print(f"Manual hyperparameters saved to {hyperparams_path}")
|
76 |
-
|
77 |
-
return val_loss
|
78 |
-
|
79 |
-
|
80 |
-
def run_optuna_study(config):
|
81 |
-
# Set seed for reproducibility
|
82 |
-
set_seed(config["seed"])
|
83 |
-
|
84 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
85 |
-
(
|
86 |
-
train_dataset,
|
87 |
-
train_cell_id_mapping,
|
88 |
-
val_dataset,
|
89 |
-
val_cell_id_mapping,
|
90 |
-
num_labels_list,
|
91 |
-
) = preload_and_process_data(config)
|
92 |
-
train_loader = get_data_loader(train_dataset, config["batch_size"])
|
93 |
-
val_loader = get_data_loader(val_dataset, config["batch_size"])
|
94 |
-
|
95 |
-
if config["use_manual_hyperparameters"]:
|
96 |
-
train_model(
|
97 |
-
config,
|
98 |
-
device,
|
99 |
-
train_loader,
|
100 |
-
val_loader,
|
101 |
-
train_cell_id_mapping,
|
102 |
-
val_cell_id_mapping,
|
103 |
-
num_labels_list,
|
104 |
-
)
|
105 |
-
else:
|
106 |
-
objective_with_config_and_data = functools.partial(
|
107 |
-
objective,
|
108 |
-
train_loader=train_loader,
|
109 |
-
val_loader=val_loader,
|
110 |
-
train_cell_id_mapping=train_cell_id_mapping,
|
111 |
-
val_cell_id_mapping=val_cell_id_mapping,
|
112 |
-
num_labels_list=num_labels_list,
|
113 |
-
config=config,
|
114 |
-
device=device,
|
115 |
-
)
|
116 |
-
|
117 |
-
study = optuna.create_study(
|
118 |
-
direction="minimize", # Minimize validation loss
|
119 |
-
study_name=config["study_name"],
|
120 |
-
# storage=config["storage"],
|
121 |
-
load_if_exists=True,
|
122 |
-
)
|
123 |
-
|
124 |
-
study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
|
125 |
-
|
126 |
-
# After finding the best trial
|
127 |
-
best_params = study.best_trial.params
|
128 |
-
best_task_weights = study.best_trial.user_attrs["task_weights"]
|
129 |
-
print("Saving the best model and its hyperparameters...")
|
130 |
-
|
131 |
-
# Saving model as before
|
132 |
-
best_model = GeneformerMultiTask(
|
133 |
-
config["pretrained_path"],
|
134 |
-
num_labels_list,
|
135 |
-
dropout_rate=best_params["dropout_rate"],
|
136 |
-
use_task_weights=config["use_task_weights"],
|
137 |
-
task_weights=best_task_weights,
|
138 |
-
)
|
139 |
-
|
140 |
-
# Get the best model state dictionary
|
141 |
-
best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
|
142 |
-
|
143 |
-
# Remove the "module." prefix from the state dictionary keys if present
|
144 |
-
best_model_state_dict = {
|
145 |
-
k.replace("module.", ""): v for k, v in best_model_state_dict.items()
|
146 |
-
}
|
147 |
-
|
148 |
-
# Load the modified state dictionary into the model, skipping unexpected keys
|
149 |
-
best_model.load_state_dict(best_model_state_dict, strict=False)
|
150 |
-
|
151 |
-
model_save_directory = os.path.join(
|
152 |
-
config["model_save_path"], "GeneformerMultiTask"
|
153 |
-
)
|
154 |
-
save_model(best_model, model_save_directory)
|
155 |
-
|
156 |
-
# Additionally, save the best hyperparameters and task weights
|
157 |
-
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
158 |
-
|
159 |
-
with open(hyperparams_path, "w") as f:
|
160 |
-
json.dump({**best_params, "task_weights": best_task_weights}, f)
|
161 |
-
print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl/utils.py
DELETED
@@ -1,129 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import shutil
|
3 |
-
|
4 |
-
from sklearn.metrics import accuracy_score, f1_score
|
5 |
-
from sklearn.preprocessing import LabelEncoder
|
6 |
-
from transformers import AutoConfig, BertConfig, BertModel
|
7 |
-
|
8 |
-
from .imports import *
|
9 |
-
|
10 |
-
|
11 |
-
def save_model(model, model_save_directory):
|
12 |
-
if not os.path.exists(model_save_directory):
|
13 |
-
os.makedirs(model_save_directory)
|
14 |
-
|
15 |
-
# Get the state dict
|
16 |
-
if isinstance(model, nn.DataParallel):
|
17 |
-
model_state_dict = (
|
18 |
-
model.module.state_dict()
|
19 |
-
) # Use model.module to access the underlying model
|
20 |
-
else:
|
21 |
-
model_state_dict = model.state_dict()
|
22 |
-
|
23 |
-
# Remove the "module." prefix from the keys if present
|
24 |
-
model_state_dict = {
|
25 |
-
k.replace("module.", ""): v for k, v in model_state_dict.items()
|
26 |
-
}
|
27 |
-
|
28 |
-
model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
29 |
-
torch.save(model_state_dict, model_save_path)
|
30 |
-
|
31 |
-
# Save the model configuration
|
32 |
-
if isinstance(model, nn.DataParallel):
|
33 |
-
model.module.config.to_json_file(
|
34 |
-
os.path.join(model_save_directory, "config.json")
|
35 |
-
)
|
36 |
-
else:
|
37 |
-
model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
|
38 |
-
|
39 |
-
print(f"Model and configuration saved to {model_save_directory}")
|
40 |
-
|
41 |
-
|
42 |
-
def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
|
43 |
-
task_metrics = {}
|
44 |
-
for task_name in task_true_labels.keys():
|
45 |
-
true_labels = task_true_labels[task_name]
|
46 |
-
pred_labels = task_pred_labels[task_name]
|
47 |
-
f1 = f1_score(true_labels, pred_labels, average="macro")
|
48 |
-
accuracy = accuracy_score(true_labels, pred_labels)
|
49 |
-
task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
|
50 |
-
return task_metrics
|
51 |
-
|
52 |
-
|
53 |
-
def calculate_combined_f1(combined_labels, combined_preds):
|
54 |
-
# Initialize the LabelEncoder
|
55 |
-
le = LabelEncoder()
|
56 |
-
|
57 |
-
# Fit and transform combined labels and predictions to numerical values
|
58 |
-
le.fit(combined_labels + combined_preds)
|
59 |
-
encoded_true_labels = le.transform(combined_labels)
|
60 |
-
encoded_pred_labels = le.transform(combined_preds)
|
61 |
-
|
62 |
-
# Print out the mapping for sanity check
|
63 |
-
print("\nLabel Encoder Mapping:")
|
64 |
-
for index, class_label in enumerate(le.classes_):
|
65 |
-
print(f"'{class_label}': {index}")
|
66 |
-
|
67 |
-
# Calculate accuracy
|
68 |
-
accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
|
69 |
-
|
70 |
-
# Calculate F1 Macro score
|
71 |
-
f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
|
72 |
-
|
73 |
-
return f1, accuracy
|
74 |
-
|
75 |
-
|
76 |
-
# def save_model_without_heads(original_model_save_directory):
|
77 |
-
# # Create a new directory for the model without heads
|
78 |
-
# new_model_save_directory = original_model_save_directory + "_No_Heads"
|
79 |
-
# if not os.path.exists(new_model_save_directory):
|
80 |
-
# os.makedirs(new_model_save_directory)
|
81 |
-
|
82 |
-
# # Load the model state dictionary
|
83 |
-
# model_state_dict = torch.load(
|
84 |
-
# os.path.join(original_model_save_directory, "pytorch_model.bin")
|
85 |
-
# )
|
86 |
-
|
87 |
-
# # Initialize a new BERT model without the classification heads
|
88 |
-
# config = BertConfig.from_pretrained(
|
89 |
-
# os.path.join(original_model_save_directory, "config.json")
|
90 |
-
# )
|
91 |
-
# model_without_heads = BertModel(config)
|
92 |
-
|
93 |
-
# # Filter the state dict to exclude classification heads
|
94 |
-
# model_without_heads_state_dict = {
|
95 |
-
# k: v
|
96 |
-
# for k, v in model_state_dict.items()
|
97 |
-
# if not k.startswith("classification_heads")
|
98 |
-
# }
|
99 |
-
|
100 |
-
# # Load the filtered state dict into the model
|
101 |
-
# model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
|
102 |
-
|
103 |
-
# # Save the model without heads
|
104 |
-
# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
|
105 |
-
# torch.save(model_without_heads.state_dict(), model_save_path)
|
106 |
-
|
107 |
-
# # Copy the configuration file
|
108 |
-
# shutil.copy(
|
109 |
-
# os.path.join(original_model_save_directory, "config.json"),
|
110 |
-
# new_model_save_directory,
|
111 |
-
# )
|
112 |
-
|
113 |
-
# print(f"Model without classification heads saved to {new_model_save_directory}")
|
114 |
-
|
115 |
-
|
116 |
-
def get_layer_freeze_range(pretrained_path):
|
117 |
-
"""
|
118 |
-
Dynamically determines the number of layers to freeze based on the model depth from its configuration.
|
119 |
-
Args:
|
120 |
-
pretrained_path (str): Path to the pretrained model directory or model identifier.
|
121 |
-
Returns:
|
122 |
-
dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
|
123 |
-
"""
|
124 |
-
if pretrained_path:
|
125 |
-
config = AutoConfig.from_pretrained(pretrained_path)
|
126 |
-
total_layers = config.num_hidden_layers
|
127 |
-
return {"min": 0, "max": total_layers - 1}
|
128 |
-
else:
|
129 |
-
return {"min": 0, "max": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|