This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50)
  1. .gitattributes +2 -2
  2. MANIFEST.in +3 -4
  3. README.md +8 -27
  4. config.json +8 -9
  5. docs/source/about.rst +5 -9
  6. docs/source/api.rst +0 -8
  7. docs/source/geneformer.mtl_classifier.rst +0 -11
  8. docs/source/geneformer.tokenizer.rst +1 -2
  9. docs/source/index.rst +1 -1
  10. examples/multitask_cell_classification.ipynb +0 -420
  11. examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +1 -3
  12. examples/tokenizing_scRNAseq_data.ipynb +1 -5
  13. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  14. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  15. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  16. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  17. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  18. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  19. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  20. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
  21. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
  22. {gf-12L-30M-i2048 → geneformer-12L-30M}/config.json +0 -0
  23. {gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin +0 -0
  24. {gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin +0 -0
  25. geneformer/__init__.py +1 -14
  26. geneformer/classifier.py +56 -189
  27. geneformer/classifier_utils.py +35 -258
  28. geneformer/collator_for_classification.py +74 -139
  29. geneformer/emb_extractor.py +50 -101
  30. geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
  31. geneformer/evaluation_utils.py +5 -5
  32. geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl +0 -3
  33. geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl +0 -3
  34. geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl +0 -3
  35. geneformer/gene_median_dictionary.pkl +0 -0
  36. geneformer/gene_median_dictionary_gc95M.pkl +0 -3
  37. geneformer/{gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl → gene_name_id_dict.pkl} +0 -0
  38. geneformer/gene_name_id_dict_gc95M.pkl +0 -3
  39. geneformer/in_silico_perturber.py +136 -776
  40. geneformer/in_silico_perturber_stats.py +26 -76
  41. geneformer/mtl/__init__.py +0 -1
  42. geneformer/mtl/collators.py +0 -76
  43. geneformer/mtl/data.py +0 -150
  44. geneformer/mtl/eval_utils.py +0 -88
  45. geneformer/mtl/imports.py +0 -43
  46. geneformer/mtl/model.py +0 -121
  47. geneformer/mtl/optuna_utils.py +0 -27
  48. geneformer/mtl/train.py +0 -380
  49. geneformer/mtl/train_utils.py +0 -161
  50. geneformer/mtl/utils.py +0 -129
.gitattributes CHANGED
@@ -14,11 +14,10 @@
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
17
- *.pkl filter=lfs diff=lfs merge=lfs -text
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
21
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
22
  *.tar.* filter=lfs diff=lfs merge=lfs -text
23
  *.tflite filter=lfs diff=lfs merge=lfs -text
24
  *.tgz filter=lfs diff=lfs merge=lfs -text
@@ -26,4 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
29
  model.safetensors filter=lfs diff=lfs merge=lfs -text
 
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ geneformer/gene_name_id_dict.pkl filter=lfs diff=lfs merge=lfs -text
29
  model.safetensors filter=lfs diff=lfs merge=lfs -text
MANIFEST.in CHANGED
@@ -1,4 +1,3 @@
1
- include geneformer/gene_median_dictionary_gc95M.pkl
2
- include geneformer/gene_name_id_dict_gc95M.pkl
3
- include geneformer/ensembl_mapping_dict_gc95M.pkl
4
- include geneformer/token_dictionary_gc95M.pkl
 
1
+ include geneformer/gene_median_dictionary.pkl
2
+ include geneformer/token_dictionary.pkl
3
+ include geneformer/gene_name_id_dict.pkl
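
For context on what the trimmed MANIFEST.in now ships, the sketch below shows one way to load the three bundled dictionary pickles from an installed copy of the package. It is illustrative only: the file names come from the lines above, while the assumption that each pickle is a plain dict (e.g. gene identifier to token id for the token dictionary) is inferred from comments elsewhere in this diff.

```python
# Illustrative only: load the dictionary pickles that MANIFEST.in bundles with
# the package. Assumes geneformer is installed and each pickle is a plain dict.
import pickle
from pathlib import Path

import geneformer

pkg_dir = Path(geneformer.__file__).parent

with open(pkg_dir / "token_dictionary.pkl", "rb") as f:
    token_dict = pickle.load(f)  # gene identifier -> integer token (assumed layout)

with open(pkg_dir / "gene_median_dictionary.pkl", "rb") as f:
    gene_median_dict = pickle.load(f)  # gene identifier -> median expression (assumed layout)

print(len(token_dict), "tokens;", len(gene_median_dict), "gene medians")
```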
 
README.md CHANGED
@@ -3,38 +3,23 @@ datasets: ctheodoris/Genecorpus-30M
3
  license: apache-2.0
4
  ---
5
  # Geneformer
6
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
7
 
8
- - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
9
- - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
10
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
11
 
12
  # Model Description
13
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
14
 
15
- Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
16
-
17
- The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
18
 
19
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
20
 
21
- During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
22
-
23
- The repository includes the following pretrained models:
24
-
25
- L=layers\
26
- M=millions of cells used for pretraining\
27
- i=input size\
28
- (pretraining date)
29
 
30
- - GF-6L-30M-i2048 (June 2021)
31
- - GF-12L-30M-i2048 (June 2021)
32
- - GF-12L-95M-i4096 (April 2024)
33
- - GF-20L-95M-i4096 (April 2024)
34
 
35
- The current default model in the main directory of the repository is GF-12L-95M-i4096.
36
-
37
- The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
38
 
39
  # Application
40
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -64,7 +49,7 @@ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) inc
64
  - in silico perturbation to determine transcription factor cooperativity
65
 
66
  # Installation
67
- In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install (~20s):
68
 
69
  ```bash
70
  # Make sure you have git-lfs installed (https://git-lfs.com)
@@ -85,7 +70,3 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
85
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
86
 
87
  Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
88
-
89
- # Citations
90
- - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
91
- - H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
 
3
  license: apache-2.0
4
  ---
5
  # Geneformer
6
+ Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
7
 
8
+ - See [our manuscript](https://rdcu.be/ddrx0) for details.
 
9
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
10
 
11
  # Model Description
12
+ Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
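
As a reading aid for the paragraph above, here is a minimal sketch of a rank value encoding: genes detected in a cell are ordered by their expression scaled by each gene's corpus-wide median. The function name, argument layout, and the 2048-token cap are illustrative assumptions; the released tokenizer's exact normalization may differ.

```python
# Minimal sketch of a rank value encoding (illustrative; not the repo's tokenizer).
import numpy as np

def rank_value_encode(counts, corpus_medians, gene_tokens, max_len=2048):
    counts = np.asarray(counts, dtype=float)
    medians = np.asarray(corpus_medians, dtype=float)   # assumed nonzero
    scaled = counts / medians                  # deprioritize ubiquitously high genes
    expressed = np.nonzero(counts)[0]          # only genes detected in this cell
    order = expressed[np.argsort(-scaled[expressed])]   # highest scaled value first
    return [gene_tokens[i] for i in order[:max_len]]

print(rank_value_encode([0, 5, 2, 9], [1.0, 10.0, 1.0, 3.0], ["g0", "g1", "g2", "g3"]))
# ['g3', 'g2', 'g1'] -- ranked by corpus-normalized expression
```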
13
 
14
+ The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
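
The masked learning objective described above can be illustrated with a toy masking routine; the mask token id, the -100 ignore label, and the exact masking recipe are assumptions for illustration, not the repository's pretraining code.

```python
# Toy sketch of masking 15% of gene tokens for a masked learning objective.
import numpy as np

def mask_for_pretraining(token_ids, mask_token_id, mask_prob=0.15, seed=0):
    rng = np.random.default_rng(seed)
    token_ids = np.asarray(token_ids)
    labels = np.full_like(token_ids, -100)         # -100 = position ignored by the loss
    mask = rng.random(token_ids.shape) < mask_prob
    labels[mask] = token_ids[mask]                 # model must predict the original gene
    inputs = np.where(mask, mask_token_id, token_ids)
    return inputs, labels
```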
 
 
15
 
16
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
17
 
18
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on ~30 million human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
 
 
 
 
 
 
 
19
 
20
+ In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
 
 
 
21
 
22
+ Both the 6 and 12 layer Geneformer models were pretrained in June 2021.
 
 
23
 
24
  # Application
25
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
 
49
  - in silico perturbation to determine transcription factor cooperativity
50
 
51
  # Installation
52
+ In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico perturbation with either the pretrained or fine-tuned models. To install:
53
 
54
  ```bash
55
  # Make sure you have git-lfs installed (https://git-lfs.com)
 
70
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
71
 
72
  Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
 
 
 
 
config.json CHANGED
@@ -3,22 +3,21 @@
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
- "classifier_dropout": null,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.37.1",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
- "vocab_size": 20275
24
  }
 
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
+ "gradient_checkpointing": false,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 256,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 512,
12
  "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 2048,
14
  "model_type": "bert",
15
+ "num_attention_heads": 4,
16
+ "num_hidden_layers": 6,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0",
 
20
  "type_vocab_size": 2,
21
  "use_cache": true,
22
+ "vocab_size": 25426
23
  }
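
The updated config.json above describes the 6-layer, 256-hidden, 2048-position architecture. Below is a hedged sketch of instantiating that architecture with Hugging Face transformers; the hyperparameter values are copied from the config shown, and everything else is standard BertConfig/BertForMaskedLM usage.

```python
# Sketch: build the architecture described by the updated config.json above.
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(
    attention_probs_dropout_prob=0.02,
    hidden_act="relu",
    hidden_dropout_prob=0.02,
    hidden_size=256,
    intermediate_size=512,
    max_position_embeddings=2048,
    num_attention_heads=4,
    num_hidden_layers=6,
    pad_token_id=0,
    vocab_size=25426,
)
model = BertForMaskedLM(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```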
docs/source/about.rst CHANGED
@@ -4,13 +4,11 @@ About
4
  Model Description
5
  -----------------
6
 
7
- **Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
8
 
9
- In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the original 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
10
 
11
- Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-12L-30M-i2048/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
12
-
13
- Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
14
 
15
  Application
16
  -----------
@@ -41,9 +39,7 @@ Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ i
41
  | - in silico perturbation to determine transcription factor targets
42
  | - in silico perturbation to determine transcription factor cooperativity
43
 
44
- Citations
45
- ---------
46
 
47
  | C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
48
-
49
- | H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
 
4
  Model Description
5
  -----------------
6
 
7
+ **Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of ~30 million single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on ~30 million human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
8
 
9
+ In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
10
 
11
+ Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/pytorch_model.bin>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer-12L-30M/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
 
 
12
 
13
  Application
14
  -----------
 
39
  | - in silico perturbation to determine transcription factor targets
40
  | - in silico perturbation to determine transcription factor cooperativity
41
 
42
+ Citation
43
+ --------
44
 
45
  | C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
 
 
docs/source/api.rst CHANGED
@@ -17,14 +17,6 @@ Classifier
17
 
18
  geneformer.classifier
19
 
20
- Multitask Classifier
21
- --------------------
22
-
23
- .. toctree::
24
- :maxdepth: 1
25
-
26
- geneformer.mtl_classifier
27
-
28
  Embedding Extractor
29
  -------------------
30
 
 
17
 
18
  geneformer.classifier
19

20
  Embedding Extractor
21
  -------------------
22
 
docs/source/geneformer.mtl_classifier.rst DELETED
@@ -1,11 +0,0 @@
1
- geneformer.mtl\_classifier
2
- ==========================
3
-
4
- .. automodule:: geneformer.mtl_classifier
5
- :members:
6
- :undoc-members:
7
- :show-inheritance:
8
- :exclude-members:
9
- valid_option_dict,
10
- validate_options,
11
- validate_additional_options

docs/source/geneformer.tokenizer.rst CHANGED
@@ -11,5 +11,4 @@ geneformer.tokenizer
11
  tokenize_files,
12
  tokenize_loom,
13
  rank_genes,
14
- tokenize_cell,
15
- sum_ensembl_ids
 
11
  tokenize_files,
12
  tokenize_loom,
13
  rank_genes,
14
+ tokenize_cell
 
docs/source/index.rst CHANGED
@@ -1,7 +1,7 @@
1
  Geneformer
2
  ==========
3
 
4
- Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
5
 
6
  See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
7
 
 
1
  Geneformer
2
  ==========
3
 
4
+ Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in network biology.
5
 
6
  See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
7
 
examples/multitask_cell_classification.ipynb DELETED
@@ -1,420 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "866f100c-e11a-4e7b-a37c-831775d845a7",
6
- "metadata": {},
7
- "source": [
8
- "# Geneformer Multi-Task Cell Classifier Tutorial\n",
9
- "\n",
10
- "This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
16
- "metadata": {},
17
- "source": [
18
- "## 1. Installation and Imports\n",
19
- "\n",
20
- "First import the necessary modules."
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": 3,
26
- "id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
27
- "metadata": {},
28
- "outputs": [],
29
- "source": [
30
- "from geneformer import MTLClassifier"
31
- ]
32
- },
33
- {
34
- "cell_type": "markdown",
35
- "id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
36
- "metadata": {},
37
- "source": [
38
- "## 2. Set up Paths and Parameters\n",
39
- "\n",
40
- "Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": null,
46
- "id": "04a04197-8e45-47f8-a86f-202209ea10ae",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "# Define paths\n",
51
- "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
52
- "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
53
- "train_path = \"/path/to/train/data.dataset\"\n",
54
- "val_path = \"/path/to/val/data.dataset\"\n",
55
- "test_path = \"/path/to/test/data.dataset\"\n",
56
- "results_dir = \"/path/to/results/directory\"\n",
57
- "model_save_path = \"/path/to/model/save/path\"\n",
58
- "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
59
- "\n",
60
- "# Define tasks and hyperparameters\n",
61
- "# task_columns should be a list of column names from your dataset\n",
62
- "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
63
- "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
64
- "\n",
65
- "hyperparameters = {\n",
66
- " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
67
- " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
68
- " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
69
- " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
70
- " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
71
- " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
72
- "}"
73
- ]
74
- },
75
- {
76
- "cell_type": "markdown",
77
- "id": "31857690-a739-435a-aefd-f171fafc1b78",
78
- "metadata": {},
79
- "source": [
80
- "In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
81
- "1. Identifying the cell type\n",
82
- "2. Determining the disease state\n",
83
- "3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
84
- "\n",
85
- "These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
86
- "\n",
87
- "For example, your dataset might look something like this:\n",
88
- "\n",
89
- " | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
90
- " |----------------|-----------|-----|-----------|---------------|\n",
91
- " | cell1 | ... | ... | neuron | healthy |\n",
92
- " | cell2 | ... | ... | astrocyte | diseased |\n",
93
- " | ... | ... | ... | ... | ... |\n",
94
- "The model will learn to predict classes within 'cell_type' and 'disease_state' "
95
- ]
96
- },
97
- {
98
- "cell_type": "markdown",
99
- "id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
100
- "metadata": {},
101
- "source": [
102
- "## 3. Initialize the MTLClassifier\n",
103
- "\n",
104
- "Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": null,
110
- "id": "e27caac9-670c-409d-9313-50201c665cb9",
111
- "metadata": {},
112
- "outputs": [],
113
- "source": [
114
- "mc = MTLClassifier(\n",
115
- " task_columns=task_columns, # Our defined classification tasks\n",
116
- " study_name=\"MTLClassifier_example\",\n",
117
- " pretrained_path=pretrained_path,\n",
118
- " train_path=train_path,\n",
119
- " val_path=val_path,\n",
120
- " test_path=test_path,\n",
121
- " model_save_path=model_save_path,\n",
122
- " results_dir=results_dir,\n",
123
- " tensorboard_log_dir=tensorboard_log_dir,\n",
124
- " hyperparameters=hyperparameters,\n",
125
- " n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
126
- " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
127
- " batch_size=8, # Adjust based on available GPU memory\n",
128
- " seed=42\n",
129
- ")"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "0d729444-e3ad-4584-9659-0c464ac97462",
135
- "metadata": {},
136
- "source": [
137
- "## 4. Run Hyperparameter Optimization\n",
138
- "\n",
139
- "Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "mc.run_optuna_study()"
150
- ]
151
- },
152
- {
153
- "cell_type": "markdown",
154
- "id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
155
- "metadata": {},
156
- "source": [
157
- "## 5. Evaluate the Model on Test Data\n",
158
- "\n",
159
- "After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
166
- "metadata": {},
167
- "outputs": [],
168
- "source": [
169
- "mc.load_and_evaluate_test_model()"
170
- ]
171
- },
172
- {
173
- "cell_type": "markdown",
174
- "id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
175
- "metadata": {},
176
- "source": [
177
- "## 6. (Optional) Manual Hyperparameter Tuning\n",
178
- "\n",
179
- "If you prefer to set hyperparameters manually, you can use the following approach:"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": null,
185
- "id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
186
- "metadata": {},
187
- "outputs": [],
188
- "source": [
189
- "manual_hyperparameters = {\n",
190
- " \"learning_rate\": 0.001,\n",
191
- " \"warmup_ratio\": 0.01,\n",
192
- " \"weight_decay\": 0.1,\n",
193
- " \"dropout_rate\": 0.1,\n",
194
- " \"lr_scheduler_type\": \"cosine\",\n",
195
- " \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
196
- " \"max_layers_to_freeze\": 2\n",
197
- "}\n",
198
- "\n",
199
- "mc_manual = MTLClassifier(\n",
200
- " task_columns=task_columns,\n",
201
- " study_name=\"mtl_manual\",\n",
202
- " pretrained_path=pretrained_path,\n",
203
- " train_path=train_path,\n",
204
- " val_path=val_path,\n",
205
- " test_path=test_path,\n",
206
- " model_save_path=model_save_path,\n",
207
- " results_dir=results_dir,\n",
208
- " tensorboard_log_dir=tensorboard_log_dir,\n",
209
- " manual_hyperparameters=manual_hyperparameters,\n",
210
- " use_manual_hyperparameters=True,\n",
211
- " epochs=10,\n",
212
- " batch_size=32,\n",
213
- " seed=42\n",
214
- ")\n",
215
- "\n",
216
- "mc_manual.run_manual_tuning()"
217
- ]
218
- },
219
- {
220
- "cell_type": "markdown",
221
- "id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
222
- "metadata": {},
223
- "source": [
224
- "# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
225
- "This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
226
- ]
227
- },
228
- {
229
- "cell_type": "code",
230
- "execution_count": null,
231
- "id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
232
- "metadata": {},
233
- "outputs": [],
234
- "source": [
235
- "from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
242
- "metadata": {},
243
- "outputs": [],
244
- "source": [
245
- "# Define paths\n",
246
- "model_directory = \"/path/to/model/save/path\"\n",
247
- "input_data_file = \"/path/to/input/data.dataset\"\n",
248
- "output_directory = \"/path/to/output/directory\"\n",
249
- "output_prefix = \"mtl_quantized_perturbation\"\n",
250
- "\n",
251
- "# Define parameters\n",
252
- "perturb_type = \"delete\" # or \"overexpress\"\n",
253
- "\n",
254
- "# Define cell states to model\n",
255
- "cell_states_to_model = {\n",
256
- " \"state_key\": \"disease_state\", \n",
257
- " \"start_state\": \"disease\", \n",
258
- " \"goal_state\": \"control\"\n",
259
- "}\n",
260
- "\n",
261
- "# Define filter data\n",
262
- "filter_data_dict = {\n",
263
- " \"cell_type\": [\"Fibroblast\"]\n",
264
- "}"
265
- ]
266
- },
267
- {
268
- "cell_type": "markdown",
269
- "id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
270
- "metadata": {},
271
- "source": [
272
- "## 3. Extract State Embeddings\n",
273
- "\n",
274
- "Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
275
- ]
276
- },
277
- {
278
- "cell_type": "code",
279
- "execution_count": null,
280
- "id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
281
- "metadata": {},
282
- "outputs": [],
283
- "source": [
284
- "# Initialize EmbExtractor\n",
285
- "embex = EmbExtractor(\n",
286
- " filter_data_dict=filter_data_dict,\n",
287
- " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
- " emb_layer=0, # Use the second to last layer\n",
289
- " emb_mode = \"cls\",\n",
290
- " summary_stat=\"exact_mean\",\n",
291
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
- " nproc=4\n",
293
- ")\n",
294
- "\n",
295
- "# Extract state embeddings\n",
296
- "state_embs_dict = embex.get_state_embs(\n",
297
- " cell_states_to_model,\n",
298
- " model_directory=model_directory,\n",
299
- " input_data_file=input_data_file,\n",
300
- " output_directory=output_directory,\n",
301
- " output_prefix=output_prefix\n",
302
- ")"
303
- ]
304
- },
305
- {
306
- "cell_type": "markdown",
307
- "id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
308
- "metadata": {},
309
- "source": [
310
- "## 4. Initialize the InSilicoPerturber\n",
311
- "\n",
312
- "Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
313
- ]
314
- },
315
- {
316
- "cell_type": "code",
317
- "execution_count": null,
318
- "id": "09f985a1-91bc-4e8d-8001-a3663531b570",
319
- "metadata": {},
320
- "outputs": [],
321
- "source": [
322
- "# Initialize InSilicoPerturber\n",
323
- "isp = InSilicoPerturber(\n",
324
- " perturb_type=perturb_type,\n",
325
- " genes_to_perturb=\"all\", # Perturb all genes\n",
326
- " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
- " emb_mode=\"cls\", # Use CLS token embedding\n",
328
- " cell_states_to_model=cell_states_to_model,\n",
329
- " state_embs_dict=state_embs_dict,\n",
330
- " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
331
- " emb_layer=0, \n",
332
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
333
- " nproc=1\n",
334
- ")"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
340
- "metadata": {},
341
- "source": [
342
- "## 5. Run In Silico Perturbation\n",
343
- "\n",
344
- "Run the in silico perturbation on the dataset."
345
- ]
346
- },
347
- {
348
- "cell_type": "code",
349
- "execution_count": null,
350
- "id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
351
- "metadata": {},
352
- "outputs": [],
353
- "source": [
354
- "# Run perturbation and output intermediate files\n",
355
- "isp.perturb_data(\n",
356
- " model_directory=model_directory,\n",
357
- " input_data_file=input_data_file,\n",
358
- " output_directory=output_directory,\n",
359
- " output_prefix=output_prefix\n",
360
- ")"
361
- ]
362
- },
363
- {
364
- "cell_type": "markdown",
365
- "id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
366
- "metadata": {},
367
- "source": [
368
- "## 6. Process Results with InSilicoPerturberStats\n",
369
- "\n",
370
- "After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
371
- ]
372
- },
373
- {
374
- "cell_type": "code",
375
- "execution_count": null,
376
- "id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
377
- "metadata": {},
378
- "outputs": [],
379
- "source": [
380
- "# Initialize InSilicoPerturberStats\n",
381
- "ispstats = InSilicoPerturberStats(\n",
382
- " mode=\"goal_state_shift\",\n",
383
- " genes_perturbed=\"all\",\n",
384
- " combos=0,\n",
385
- " anchor_gene=None,\n",
386
- " cell_states_to_model=cell_states_to_model\n",
387
- ")\n",
388
- "\n",
389
- "# Process stats and output final .csv\n",
390
- "ispstats.get_stats(\n",
391
- " input_data_file,\n",
392
- " None,\n",
393
- " output_directory,\n",
394
- " output_prefix\n",
395
- ")"
396
- ]
397
- }
398
- ],
399
- "metadata": {
400
- "kernelspec": {
401
- "display_name": "Python 3 (ipykernel)",
402
- "language": "python",
403
- "name": "python3"
404
- },
405
- "language_info": {
406
- "codemirror_mode": {
407
- "name": "ipython",
408
- "version": 3
409
- },
410
- "file_extension": ".py",
411
- "mimetype": "text/x-python",
412
- "name": "python",
413
- "nbconvert_exporter": "python",
414
- "pygments_lexer": "ipython3",
415
- "version": "3.11.5"
416
- }
417
- },
418
- "nbformat": 4,
419
- "nbformat_minor": 5
420
- }

examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py CHANGED
@@ -138,9 +138,7 @@ training_args = {
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
- "save_steps": np.floor(
142
- num_examples / geneformer_batch_size / 8
143
- ), # 8 saves per epoch
144
  "logging_steps": 1000,
145
  "output_dir": training_output_dir,
146
  "logging_dir": logging_dir,
 
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
+ "save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
 
 
142
  "logging_steps": 1000,
143
  "output_dir": training_output_dir,
144
  "logging_dir": logging_dir,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -25,11 +25,7 @@
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
- "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n",
29
- "\n",
30
- "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
31
- "\n",
32
- "#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
33
  ]
34
  },
35
  {
 
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
+ "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
 
 
 
 
29
  ]
30
  },
31
  {
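
To tie the tokenization notes above to code, here is a hedged sketch of the tokenization workflow: converting .loom/.h5ad single-cell data into tokenized rank value encodings. The custom attribute dictionary and the paths are placeholders, and the tokenize_data call signature follows the repository's tokenization example and may differ between versions.

```python
# Hedged sketch of the tokenization workflow described in the notebook above.
from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer(
    {"cell_type": "cell_type", "organ_major": "organ_major"},  # metadata to carry over (placeholder keys)
    nproc=4,
)
tk.tokenize_data(
    "/path/to/loom_or_h5ad/directory",   # input directory of .loom or .h5ad files
    "/path/to/output/directory",         # where the tokenized .dataset is written
    "my_dataset",                        # output prefix
    file_format="loom",
)
```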
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertForMaskedLM"
4
- ],
5
- "attention_probs_dropout_prob": 0.02,
6
- "classifier_dropout": null,
7
- "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
- "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
- "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 4096,
14
- "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
- "pad_token_id": 0,
18
- "position_embedding_type": "absolute",
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.37.2",
21
- "type_vocab_size": 2,
22
- "use_cache": true,
23
- "vocab_size": 20275
24
- }

fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
- size 152363342

{gf-12L-30M-i2048 → geneformer-12L-30M}/config.json RENAMED
File without changes
{gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin RENAMED
File without changes
{gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -1,14 +1,4 @@
1
  # ruff: noqa: F401
2
- import warnings
3
- from pathlib import Path
4
-
5
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
-
7
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
11
-
12
  from . import (
13
  collator_for_classification,
14
  emb_extractor,
@@ -21,7 +11,7 @@ from .collator_for_classification import (
21
  DataCollatorForCellClassification,
22
  DataCollatorForGeneClassification,
23
  )
24
- from .emb_extractor import EmbExtractor, get_embs
25
  from .in_silico_perturber import InSilicoPerturber
26
  from .in_silico_perturber_stats import InSilicoPerturberStats
27
  from .pretrainer import GeneformerPretrainer
@@ -29,6 +19,3 @@ from .tokenizer import TranscriptomeTokenizer
29
 
30
  from . import classifier # noqa # isort:skip
31
  from .classifier import Classifier # noqa # isort:skip
32
-
33
- from . import mtl_classifier # noqa # isort:skip
34
- from .mtl_classifier import MTLClassifier # noqa # isort:skip
 
1
  # ruff: noqa: F401
2
  from . import (
3
  collator_for_classification,
4
  emb_extractor,
 
11
  DataCollatorForCellClassification,
12
  DataCollatorForGeneClassification,
13
  )
14
+ from .emb_extractor import EmbExtractor
15
  from .in_silico_perturber import InSilicoPerturber
16
  from .in_silico_perturber_stats import InSilicoPerturberStats
17
  from .pretrainer import GeneformerPretrainer
 
19
 
20
  from . import classifier # noqa # isort:skip
21
  from .classifier import Classifier # noqa # isort:skip
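
As a quick check of the public surface left by the trimmed __init__.py, the following import should still succeed after this change; the names are taken from the hunks above, and this smoke test is illustrative rather than part of the diff.

```python
# Smoke test of the names the edited geneformer/__init__.py still exports.
from geneformer import (
    Classifier,
    DataCollatorForCellClassification,
    DataCollatorForGeneClassification,
    EmbExtractor,
    GeneformerPretrainer,
    InSilicoPerturber,
    InSilicoPerturberStats,
    TranscriptomeTokenizer,
)
```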
 
 
 
geneformer/classifier.py CHANGED
@@ -53,18 +53,16 @@ from pathlib import Path
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
 
56
  from tqdm.auto import tqdm, trange
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
59
 
60
- from . import (
61
- TOKEN_DICTIONARY_FILE,
62
- DataCollatorForCellClassification,
63
- DataCollatorForGeneClassification,
64
- )
65
  from . import classifier_utils as cu
66
  from . import evaluation_utils as eu
67
  from . import perturber_utils as pu
 
68
 
69
  sns.set()
70
 
@@ -75,7 +73,6 @@ logger = logging.getLogger(__name__)
75
  class Classifier:
76
  valid_option_dict = {
77
  "classifier": {"cell", "gene"},
78
- "quantize": {bool, dict},
79
  "cell_state_dict": {None, dict},
80
  "gene_class_dict": {None, dict},
81
  "filter_data": {None, dict},
@@ -89,7 +86,6 @@ class Classifier:
89
  "no_eval": {bool},
90
  "stratify_splits_col": {None, str},
91
  "forward_batch_size": {int},
92
- "token_dictionary_file": {None, str},
93
  "nproc": {int},
94
  "ngpu": {int},
95
  }
@@ -97,7 +93,6 @@ class Classifier:
97
  def __init__(
98
  self,
99
  classifier=None,
100
- quantize=False,
101
  cell_state_dict=None,
102
  gene_class_dict=None,
103
  filter_data=None,
@@ -112,7 +107,6 @@ class Classifier:
112
  stratify_splits_col=None,
113
  no_eval=False,
114
  forward_batch_size=100,
115
- token_dictionary_file=None,
116
  nproc=4,
117
  ngpu=1,
118
  ):
@@ -123,13 +117,6 @@ class Classifier:
123
 
124
  classifier : {"cell", "gene"}
125
  | Whether to fine-tune a cell state or gene classifier.
126
- quantize : bool, dict
127
- | Whether to fine-tune a quantized model.
128
- | If True and no config provided, will use default.
129
- | Will use custom config if provided.
130
- | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
131
- | For example: {"bnb_config": BitsAndBytesConfig(...),
132
- | "peft_config": LoraConfig(...)}
133
  cell_state_dict : None, dict
134
  | Cell states to fine-tune model to distinguish.
135
  | Two-item dictionary with keys: state_key and states
@@ -188,9 +175,6 @@ class Classifier:
188
  | Otherwise, will perform eval during training.
189
  forward_batch_size : int
190
  | Batch size for forward pass (for evaluation, not training).
191
- token_dictionary_file : None, str
192
- | Default is to use token dictionary file from Geneformer
193
- | Otherwise, will load custom gene token dictionary.
194
  nproc : int
195
  | Number of CPU processes to use.
196
  ngpu : int
@@ -199,11 +183,6 @@ class Classifier:
199
  """
200
 
201
  self.classifier = classifier
202
- if self.classifier == "cell":
203
- self.model_type = "CellClassifier"
204
- elif self.classifier == "gene":
205
- self.model_type = "GeneClassifier"
206
- self.quantize = quantize
207
  self.cell_state_dict = cell_state_dict
208
  self.gene_class_dict = gene_class_dict
209
  self.filter_data = filter_data
@@ -222,7 +201,6 @@ class Classifier:
222
  self.stratify_splits_col = stratify_splits_col
223
  self.no_eval = no_eval
224
  self.forward_batch_size = forward_batch_size
225
- self.token_dictionary_file = token_dictionary_file
226
  self.nproc = nproc
227
  self.ngpu = ngpu
228
 
@@ -244,9 +222,7 @@ class Classifier:
244
  ] = self.cell_state_dict["states"]
245
 
246
  # load token dictionary (Ensembl IDs:token)
247
- if self.token_dictionary_file is None:
248
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
249
- with open(self.token_dictionary_file, "rb") as f:
250
  self.gene_token_dict = pickle.load(f)
251
 
252
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -269,7 +245,7 @@ class Classifier:
269
  f"Genes to classify {missing_genes} are not in token dictionary."
270
  )
271
  self.gene_class_dict = {
272
- k: list(set([self.gene_token_dict.get(gene) for gene in v]))
273
  for k, v in self.gene_class_dict.items()
274
  }
275
  empty_classes = []
@@ -291,7 +267,7 @@ class Classifier:
291
  continue
292
  valid_type = False
293
  for option in valid_options:
294
- if (option in [int, float, list, dict, bool, str]) and isinstance(
295
  attr_value, option
296
  ):
297
  valid_type = True
@@ -417,15 +393,6 @@ class Classifier:
417
  )
418
  raise
419
 
420
- if (attr_to_split is not None) and (attr_to_balance is None):
421
- logger.error(
422
- "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
423
- )
424
- raise
425
-
426
- if not isinstance(attr_to_balance, list):
427
- attr_to_balance = [attr_to_balance]
428
-
429
  if self.classifier == "cell":
430
  # remove cell states representing < rare_threshold of cells
431
  data = cu.remove_rare(
@@ -467,8 +434,8 @@ class Classifier:
467
  test_data_output_path = (
468
  Path(output_directory) / f"{output_prefix}_labeled_test"
469
  ).with_suffix(".dataset")
470
- data_dict["train"].save_to_disk(str(train_data_output_path))
471
- data_dict["test"].save_to_disk(str(test_data_output_path))
472
  elif (test_size is not None) and (self.classifier == "cell"):
473
  if 1 > test_size > 0:
474
  if attr_to_split is None:
@@ -483,8 +450,8 @@ class Classifier:
483
  test_data_output_path = (
484
  Path(output_directory) / f"{output_prefix}_labeled_test"
485
  ).with_suffix(".dataset")
486
- data_dict["train"].save_to_disk(str(train_data_output_path))
487
- data_dict["test"].save_to_disk(str(test_data_output_path))
488
  else:
489
  data_dict, balance_df = cu.balance_attr_splits(
490
  data,
@@ -505,19 +472,19 @@ class Classifier:
505
  test_data_output_path = (
506
  Path(output_directory) / f"{output_prefix}_labeled_test"
507
  ).with_suffix(".dataset")
508
- data_dict["train"].save_to_disk(str(train_data_output_path))
509
- data_dict["test"].save_to_disk(str(test_data_output_path))
510
  else:
511
  data_output_path = (
512
  Path(output_directory) / f"{output_prefix}_labeled"
513
  ).with_suffix(".dataset")
514
- data.save_to_disk(str(data_output_path))
515
  print(data_output_path)
516
  else:
517
  data_output_path = (
518
  Path(output_directory) / f"{output_prefix}_labeled"
519
  ).with_suffix(".dataset")
520
- data.save_to_disk(str(data_output_path))
521
 
522
  def train_all_data(
523
  self,
@@ -527,7 +494,6 @@ class Classifier:
527
  output_directory,
528
  output_prefix,
529
  save_eval_output=True,
530
- gene_balance=False,
531
  ):
532
  """
533
  Train cell state or gene classifier using all data.
@@ -548,9 +514,6 @@ class Classifier:
548
  save_eval_output : bool
549
  | Whether to save cross-fold eval output
550
  | Saves as pickle file of dictionary of eval metrics
551
- gene_balance : None, bool
552
- | Whether to automatically balance genes in training set.
553
- | Only available for binary gene classifications.
554
 
555
  **Output**
556
 
@@ -558,12 +521,6 @@ class Classifier:
558
 
559
  """
560
 
561
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
562
- logger.error(
563
- "Automatically balancing gene sets for training is only available for binary gene classifications."
564
- )
565
- raise
566
-
567
  ##### Load data and prepare output directory #####
568
  # load numerical id to class dictionary (id:class)
569
  with open(id_class_dict_file, "rb") as f:
@@ -595,7 +552,7 @@ class Classifier:
595
  )
596
  assert len(targets) == len(labels)
597
  data = cu.prep_gene_classifier_all_data(
598
- data, targets, labels, self.max_ncells, self.nproc, gene_balance
599
  )
600
 
601
  trainer = self.train_classifier(
@@ -614,15 +571,12 @@ class Classifier:
614
  split_id_dict=None,
615
  attr_to_split=None,
616
  attr_to_balance=None,
617
- gene_balance=False,
618
  max_trials=100,
619
  pval_threshold=0.1,
620
  save_eval_output=True,
621
  predict_eval=True,
622
  predict_trainer=False,
623
  n_hyperopt_trials=0,
624
- save_gene_split_datasets=True,
625
- debug_gene_split_datasets=False,
626
  ):
627
  """
628
  (Cross-)validate cell state or gene classifier.
@@ -657,9 +611,6 @@ class Classifier:
657
  attr_to_balance : None, list
658
  | List of attribute keys on which to balance data while splitting on attr_to_split
659
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
660
- gene_balance : None, bool
661
- | Whether to automatically balance genes in training set.
662
- | Only available for binary gene classifications.
663
  max_trials : None, int
664
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
665
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
@@ -678,19 +629,12 @@ class Classifier:
678
  n_hyperopt_trials : int
679
  | Number of trials to run for hyperparameter optimization
680
  | If 0, will not optimize hyperparameters
681
- save_gene_split_datasets : bool
682
- | Whether or not to save train, valid, and test gene-labeled datasets
683
  """
 
684
  if self.num_crossval_splits == 0:
685
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
686
  raise
687
 
688
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
689
- logger.error(
690
- "Automatically balancing gene sets for training is only available for binary gene classifications."
691
- )
692
- raise
693
-
694
  # ensure number of genes in each class is > 5 if validating model
695
  if self.classifier == "gene":
696
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
@@ -771,7 +715,7 @@ class Classifier:
771
  else:
772
  # 5-fold cross-validate
773
  num_cells = len(data)
774
- fifth_cells = int(np.floor(num_cells * 0.2))
775
  num_eval = min((self.eval_size * num_cells), fifth_cells)
776
  start = i * fifth_cells
777
  end = start + num_eval
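As a worked example of the eval-window arithmetic above (numbers are illustrative): each fold advances by one fifth of the cells, while the evaluated window is capped by eval_size.
import numpy as np

num_cells, eval_size = 10_000, 0.1
fifth_cells = int(np.floor(num_cells * 0.2))                  # 2000-cell stride between folds
for i in range(5):
    num_eval = int(min(eval_size * num_cells, fifth_cells))   # 1000 cells evaluated per fold
    start = i * fifth_cells
    end = start + num_eval
    # fold 0 evaluates cells [0, 1000), fold 1 evaluates [2000, 3000), and so on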
@@ -828,20 +772,17 @@ class Classifier:
828
  ]
829
  )
830
  assert len(targets) == len(labels)
831
- n_splits = int(1 / (1 - self.train_size))
832
- skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
833
  # (Cross-)validate
834
- test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
835
- for train_index, eval_index, test_index in tqdm(
836
- skf.split(targets, labels, test_ratio)
837
- ):
838
  print(
839
  f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
840
  )
841
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
842
  # filter data for examples containing classes for this split
843
  # subsample to max_ncells and relabel data in column "labels"
844
- train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
845
  data,
846
  targets,
847
  labels,
@@ -850,42 +791,8 @@ class Classifier:
850
  self.max_ncells,
851
  iteration_num,
852
  self.nproc,
853
- gene_balance,
854
  )
855
 
856
- if save_gene_split_datasets is True:
857
- for split_name in ["train", "valid"]:
858
- labeled_dataset_output_path = (
859
- Path(output_dir)
860
- / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
861
- ).with_suffix(".dataset")
862
- if split_name == "train":
863
- train_data.save_to_disk(str(labeled_dataset_output_path))
864
- elif split_name == "valid":
865
- eval_data.save_to_disk(str(labeled_dataset_output_path))
866
-
867
- if self.oos_test_size > 0:
868
- test_data = cu.prep_gene_classifier_split(
869
- data,
870
- targets,
871
- labels,
872
- test_index,
873
- "test",
874
- self.max_ncells,
875
- iteration_num,
876
- self.nproc,
877
- )
878
- if save_gene_split_datasets is True:
879
- test_labeled_dataset_output_path = (
880
- Path(output_dir)
881
- / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
882
- ).with_suffix(".dataset")
883
- test_data.save_to_disk(str(test_labeled_dataset_output_path))
884
- if debug_gene_split_datasets is True:
885
- logger.error(
886
- "Exiting after saving gene split datasets given debug_gene_split_datasets = True."
887
- )
888
- raise
889
  if n_hyperopt_trials == 0:
890
  trainer = self.train_classifier(
891
  model_directory,
@@ -895,15 +802,6 @@ class Classifier:
895
  ksplit_output_dir,
896
  predict_trainer,
897
  )
898
- result = self.evaluate_model(
899
- trainer.model,
900
- num_classes,
901
- id_class_dict,
902
- eval_data,
903
- predict_eval,
904
- ksplit_output_dir,
905
- output_prefix,
906
- )
907
  else:
908
  trainer = self.hyperopt_classifier(
909
  model_directory,
@@ -913,27 +811,20 @@ class Classifier:
913
  ksplit_output_dir,
914
  n_trials=n_hyperopt_trials,
915
  )
916
-
917
- model = cu.load_best_model(
918
- ksplit_output_dir, self.model_type, num_classes
919
- )
920
-
921
- if self.oos_test_size > 0:
922
- result = self.evaluate_model(
923
- model,
924
- num_classes,
925
- id_class_dict,
926
- test_data,
927
- predict_eval,
928
- ksplit_output_dir,
929
- output_prefix,
930
- )
931
  else:
932
- if iteration_num == self.num_crossval_splits:
933
- return
934
- else:
935
- iteration_num = iteration_num + 1
936
- continue
 
 
 
 
 
 
937
  results += [result]
938
  all_conf_mat = all_conf_mat + result["conf_mat"]
939
  # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
@@ -1034,13 +925,12 @@ class Classifier:
1034
  subprocess.call(f"mkdir {output_directory}", shell=True)
1035
 
1036
  ##### Load model and training args #####
1037
- model = pu.load_model(
1038
- self.model_type,
1039
- num_classes,
1040
- model_directory,
1041
- "train",
1042
- quantize=self.quantize,
1043
- )
1044
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1045
  model, self.classifier, train_data, output_directory
1046
  )
@@ -1056,31 +946,18 @@ class Classifier:
1056
  if eval_data is None:
1057
  def_training_args["evaluation_strategy"] = "no"
1058
  def_training_args["load_best_model_at_end"] = False
1059
- def_training_args.update(
1060
- {"save_strategy": "epoch", "save_total_limit": 1}
1061
- ) # only save last model for each run
1062
  training_args_init = TrainingArguments(**def_training_args)
1063
 
1064
  ##### Fine-tune the model #####
1065
  # define the data collator
1066
  if self.classifier == "cell":
1067
- data_collator = DataCollatorForCellClassification(
1068
- token_dictionary=self.gene_token_dict
1069
- )
1070
  elif self.classifier == "gene":
1071
- data_collator = DataCollatorForGeneClassification(
1072
- token_dictionary=self.gene_token_dict
1073
- )
1074
 
1075
  # define function to initiate model
1076
  def model_init():
1077
- model = pu.load_model(
1078
- self.model_type,
1079
- num_classes,
1080
- model_directory,
1081
- "train",
1082
- quantize=self.quantize,
1083
- )
1084
 
1085
  if self.freeze_layers is not None:
1086
  def_freeze_layers = self.freeze_layers
@@ -1091,8 +968,7 @@ class Classifier:
1091
  for param in module.parameters():
1092
  param.requires_grad = False
1093
 
1094
- if self.quantize is False:
1095
- model = model.to("cuda:0")
1096
  return model
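A minimal sketch of the layer-freezing idea used in model_init above, assuming a BERT-style model whose encoder layers live at model.bert.encoder.layer (attribute names can differ by architecture):
def freeze_first_layers(model, n_layers_to_freeze):
    # freeze the layers closest to the embeddings; later layers stay trainable
    for layer in model.bert.encoder.layer[:n_layers_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False
    return model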
1097
 
1098
  # create the trainer
@@ -1142,7 +1018,6 @@ class Classifier:
1142
  metric="eval_macro_f1",
1143
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1144
  ),
1145
- storage_path=output_directory,
1146
  )
1147
 
1148
  return trainer
@@ -1205,13 +1080,11 @@ class Classifier:
1205
  subprocess.call(f"mkdir {output_directory}", shell=True)
1206
 
1207
  ##### Load model and training args #####
1208
- model = pu.load_model(
1209
- self.model_type,
1210
- num_classes,
1211
- model_directory,
1212
- "train",
1213
- quantize=self.quantize,
1214
- )
1215
 
1216
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1217
  model, self.classifier, train_data, output_directory
@@ -1241,13 +1114,9 @@ class Classifier:
1241
  ##### Fine-tune the model #####
1242
  # define the data collator
1243
  if self.classifier == "cell":
1244
- data_collator = DataCollatorForCellClassification(
1245
- token_dictionary=self.gene_token_dict
1246
- )
1247
  elif self.classifier == "gene":
1248
- data_collator = DataCollatorForGeneClassification(
1249
- token_dictionary=self.gene_token_dict
1250
- )
1251
 
1252
  # create the trainer
1253
  trainer = Trainer(
@@ -1369,13 +1238,11 @@ class Classifier:
1369
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1370
 
1371
  # load previously fine-tuned model
1372
- model = pu.load_model(
1373
- self.model_type,
1374
- num_classes,
1375
- model_directory,
1376
- "eval",
1377
- quantize=self.quantize,
1378
- )
1379
 
1380
  # evaluate the model
1381
  result = self.evaluate_model(
 
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
56
+ from sklearn.model_selection import StratifiedKFold
57
  from tqdm.auto import tqdm, trange
58
  from transformers import Trainer
59
  from transformers.training_args import TrainingArguments
60
 
61
+ from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
 
 
 
 
62
  from . import classifier_utils as cu
63
  from . import evaluation_utils as eu
64
  from . import perturber_utils as pu
65
+ from .tokenizer import TOKEN_DICTIONARY_FILE
66
 
67
  sns.set()
68
 
 
73
  class Classifier:
74
  valid_option_dict = {
75
  "classifier": {"cell", "gene"},
 
76
  "cell_state_dict": {None, dict},
77
  "gene_class_dict": {None, dict},
78
  "filter_data": {None, dict},
 
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
 
89
  "nproc": {int},
90
  "ngpu": {int},
91
  }
 
93
  def __init__(
94
  self,
95
  classifier=None,
 
96
  cell_state_dict=None,
97
  gene_class_dict=None,
98
  filter_data=None,
 
107
  stratify_splits_col=None,
108
  no_eval=False,
109
  forward_batch_size=100,
 
110
  nproc=4,
111
  ngpu=1,
112
  ):
 
117
 
118
  classifier : {"cell", "gene"}
119
  | Whether to fine-tune a cell state or gene classifier.
 
 
 
 
 
 
 
120
  cell_state_dict : None, dict
121
  | Cell states to fine-tune model to distinguish.
122
  | Two-item dictionary with keys: state_key and states
 
175
  | Otherwise, will perform eval during training.
176
  forward_batch_size : int
177
  | Batch size for forward pass (for evaluation, not training).
 
 
 
178
  nproc : int
179
  | Number of CPU processes to use.
180
  ngpu : int
 
183
  """
184
 
185
  self.classifier = classifier
 
 
 
 
 
186
  self.cell_state_dict = cell_state_dict
187
  self.gene_class_dict = gene_class_dict
188
  self.filter_data = filter_data
 
201
  self.stratify_splits_col = stratify_splits_col
202
  self.no_eval = no_eval
203
  self.forward_batch_size = forward_batch_size
 
204
  self.nproc = nproc
205
  self.ngpu = ngpu
206
 
 
222
  ] = self.cell_state_dict["states"]
223
 
224
  # load token dictionary (Ensembl IDs:token)
225
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
 
 
226
  self.gene_token_dict = pickle.load(f)
227
 
228
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
 
245
  f"Genes to classify {missing_genes} are not in token dictionary."
246
  )
247
  self.gene_class_dict = {
248
+ k: set([self.gene_token_dict.get(gene) for gene in v])
249
  for k, v in self.gene_class_dict.items()
250
  }
251
  empty_classes = []
 
267
  continue
268
  valid_type = False
269
  for option in valid_options:
270
+ if (option in [int, float, list, dict, bool]) and isinstance(
271
  attr_value, option
272
  ):
273
  valid_type = True
 
393
  )
394
  raise
395
 
 
 
 
 
 
 
 
 
 
396
  if self.classifier == "cell":
397
  # remove cell states representing < rare_threshold of cells
398
  data = cu.remove_rare(
 
434
  test_data_output_path = (
435
  Path(output_directory) / f"{output_prefix}_labeled_test"
436
  ).with_suffix(".dataset")
437
+ data_dict["train"].save_to_disk(train_data_output_path)
438
+ data_dict["test"].save_to_disk(test_data_output_path)
439
  elif (test_size is not None) and (self.classifier == "cell"):
440
  if 1 > test_size > 0:
441
  if attr_to_split is None:
 
450
  test_data_output_path = (
451
  Path(output_directory) / f"{output_prefix}_labeled_test"
452
  ).with_suffix(".dataset")
453
+ data_dict["train"].save_to_disk(train_data_output_path)
454
+ data_dict["test"].save_to_disk(test_data_output_path)
455
  else:
456
  data_dict, balance_df = cu.balance_attr_splits(
457
  data,
 
472
  test_data_output_path = (
473
  Path(output_directory) / f"{output_prefix}_labeled_test"
474
  ).with_suffix(".dataset")
475
+ data_dict["train"].save_to_disk(train_data_output_path)
476
+ data_dict["test"].save_to_disk(test_data_output_path)
477
  else:
478
  data_output_path = (
479
  Path(output_directory) / f"{output_prefix}_labeled"
480
  ).with_suffix(".dataset")
481
+ data.save_to_disk(data_output_path)
482
  print(data_output_path)
483
  else:
484
  data_output_path = (
485
  Path(output_directory) / f"{output_prefix}_labeled"
486
  ).with_suffix(".dataset")
487
+ data.save_to_disk(data_output_path)
488
 
489
  def train_all_data(
490
  self,
 
494
  output_directory,
495
  output_prefix,
496
  save_eval_output=True,
 
497
  ):
498
  """
499
  Train cell state or gene classifier using all data.
 
514
  save_eval_output : bool
515
  | Whether to save cross-fold eval output
516
  | Saves as pickle file of dictionary of eval metrics
 
 
 
517
 
518
  **Output**
519
 
 
521
 
522
  """
523
 
 
 
 
 
 
 
524
  ##### Load data and prepare output directory #####
525
  # load numerical id to class dictionary (id:class)
526
  with open(id_class_dict_file, "rb") as f:
 
552
  )
553
  assert len(targets) == len(labels)
554
  data = cu.prep_gene_classifier_all_data(
555
+ data, targets, labels, self.max_ncells, self.nproc
556
  )
557
 
558
  trainer = self.train_classifier(
 
571
  split_id_dict=None,
572
  attr_to_split=None,
573
  attr_to_balance=None,
 
574
  max_trials=100,
575
  pval_threshold=0.1,
576
  save_eval_output=True,
577
  predict_eval=True,
578
  predict_trainer=False,
579
  n_hyperopt_trials=0,
 
 
580
  ):
581
  """
582
  (Cross-)validate cell state or gene classifier.
 
611
  attr_to_balance : None, list
612
  | List of attribute keys on which to balance data while splitting on attr_to_split
613
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
 
 
 
614
  max_trials : None, int
615
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
616
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
 
629
  n_hyperopt_trials : int
630
  | Number of trials to run for hyperparameter optimization
631
  | If 0, will not optimize hyperparameters
 
 
632
  """
633
+
634
  if self.num_crossval_splits == 0:
635
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
636
  raise
637
 
 
 
 
 
 
 
638
  # ensure number of genes in each class is > 5 if validating model
639
  if self.classifier == "gene":
640
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
 
715
  else:
716
  # 5-fold cross-validate
717
  num_cells = len(data)
718
+ fifth_cells = num_cells * 0.2
719
  num_eval = min((self.eval_size * num_cells), fifth_cells)
720
  start = i * fifth_cells
721
  end = start + num_eval
 
772
  ]
773
  )
774
  assert len(targets) == len(labels)
775
+ n_splits = int(1 / self.eval_size)
776
+ skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
777
  # (Cross-)validate
778
+ for train_index, eval_index in tqdm(skf.split(targets, labels)):
 
 
 
779
  print(
780
  f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
781
  )
782
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
783
  # filter data for examples containing classes for this split
784
  # subsample to max_ncells and relabel data in column "labels"
785
+ train_data, eval_data = cu.prep_gene_classifier_split(
786
  data,
787
  targets,
788
  labels,
 
791
  self.max_ncells,
792
  iteration_num,
793
  self.nproc,
 
794
  )
795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
  if n_hyperopt_trials == 0:
797
  trainer = self.train_classifier(
798
  model_directory,
 
802
  ksplit_output_dir,
803
  predict_trainer,
804
  )
 
 
 
 
 
805
  else:
806
  trainer = self.hyperopt_classifier(
807
  model_directory,
 
811
  ksplit_output_dir,
812
  n_trials=n_hyperopt_trials,
813
  )
814
+ if iteration_num == self.num_crossval_splits:
815
+ return
 
 
 
 
 
 
 
816
  else:
817
+ iteration_num = iteration_num + 1
818
+ continue
819
+ result = self.evaluate_model(
820
+ trainer.model,
821
+ num_classes,
822
+ id_class_dict,
823
+ eval_data,
824
+ predict_eval,
825
+ ksplit_output_dir,
826
+ output_prefix,
827
+ )
828
  results += [result]
829
  all_conf_mat = all_conf_mat + result["conf_mat"]
830
  # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
 
925
  subprocess.call(f"mkdir {output_directory}", shell=True)
926
 
927
  ##### Load model and training args #####
928
+ if self.classifier == "cell":
929
+ model_type = "CellClassifier"
930
+ elif self.classifier == "gene":
931
+ model_type = "GeneClassifier"
932
+
933
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
 
934
  def_training_args, def_freeze_layers = cu.get_default_train_args(
935
  model, self.classifier, train_data, output_directory
936
  )
 
946
  if eval_data is None:
947
  def_training_args["evaluation_strategy"] = "no"
948
  def_training_args["load_best_model_at_end"] = False
 
 
 
949
  training_args_init = TrainingArguments(**def_training_args)
950
 
951
  ##### Fine-tune the model #####
952
  # define the data collator
953
  if self.classifier == "cell":
954
+ data_collator = DataCollatorForCellClassification()
 
 
955
  elif self.classifier == "gene":
956
+ data_collator = DataCollatorForGeneClassification()
 
 
957
 
958
  # define function to initiate model
959
  def model_init():
960
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
 
 
 
 
 
 
961
 
962
  if self.freeze_layers is not None:
963
  def_freeze_layers = self.freeze_layers
 
968
  for param in module.parameters():
969
  param.requires_grad = False
970
 
971
+ model = model.to("cuda:0")
 
972
  return model
973
 
974
  # create the trainer
 
1018
  metric="eval_macro_f1",
1019
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1020
  ),
 
1021
  )
1022
 
1023
  return trainer
 
1080
  subprocess.call(f"mkdir {output_directory}", shell=True)
1081
 
1082
  ##### Load model and training args #####
1083
+ if self.classifier == "cell":
1084
+ model_type = "CellClassifier"
1085
+ elif self.classifier == "gene":
1086
+ model_type = "GeneClassifier"
1087
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
 
 
1088
 
1089
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1090
  model, self.classifier, train_data, output_directory
 
1114
  ##### Fine-tune the model #####
1115
  # define the data collator
1116
  if self.classifier == "cell":
1117
+ data_collator = DataCollatorForCellClassification()
 
 
1118
  elif self.classifier == "gene":
1119
+ data_collator = DataCollatorForGeneClassification()
 
 
1120
 
1121
  # create the trainer
1122
  trainer = Trainer(
 
1238
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1239
 
1240
  # load previously fine-tuned model
1241
+ if self.classifier == "cell":
1242
+ model_type = "CellClassifier"
1243
+ elif self.classifier == "gene":
1244
+ model_type = "GeneClassifier"
1245
+ model = pu.load_model(model_type, num_classes, model_directory, "eval")
 
 
1246
 
1247
  # evaluate the model
1248
  result = self.evaluate_model(
geneformer/classifier_utils.py CHANGED
@@ -1,6 +1,4 @@
1
- import json
2
  import logging
3
- import os
4
  import random
5
  from collections import Counter, defaultdict
6
 
@@ -8,7 +6,6 @@ import numpy as np
8
  import pandas as pd
9
  from scipy.stats import chisquare, ranksums
10
  from sklearn.metrics import accuracy_score, f1_score
11
- from sklearn.model_selection import StratifiedKFold, train_test_split
12
 
13
  from . import perturber_utils as pu
14
 
@@ -136,98 +133,64 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
136
  ]
137
 
138
 
139
- def prep_gene_classifier_train_eval_split(
140
- data,
141
- targets,
142
- labels,
143
- train_index,
144
- eval_index,
145
- max_ncells,
146
- iteration_num,
147
- num_proc,
148
- balance=False,
149
- ):
150
- # generate cross-validation splits
151
- train_data = prep_gene_classifier_split(
152
- data,
153
- targets,
154
- labels,
155
- train_index,
156
- "train",
157
- max_ncells,
158
- iteration_num,
159
- num_proc,
160
- balance,
161
- )
162
- eval_data = prep_gene_classifier_split(
163
- data,
164
- targets,
165
- labels,
166
- eval_index,
167
- "eval",
168
- max_ncells,
169
- iteration_num,
170
- num_proc,
171
- balance,
172
- )
173
- return train_data, eval_data
174
-
175
-
176
  def prep_gene_classifier_split(
177
- data,
178
- targets,
179
- labels,
180
- index,
181
- subset_name,
182
- max_ncells,
183
- iteration_num,
184
- num_proc,
185
- balance=False,
186
  ):
187
  # generate cross-validation splits
188
  targets = np.array(targets)
189
  labels = np.array(labels)
190
- targets_subset = targets[index]
191
- labels_subset = labels[index]
192
- label_dict_subset = dict(zip(targets_subset, labels_subset))
 
193
 
194
  # function to filter by whether contains train or eval labels
195
- def if_contains_subset_label(example):
196
- a = targets_subset
 
 
 
 
 
197
  b = example["input_ids"]
198
  return not set(a).isdisjoint(b)
199
 
200
  # filter dataset for examples containing classes for this split
201
- logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
202
- subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
203
  logger.info(
204
- f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
 
 
 
 
 
205
  )
206
-
207
- # balance gene subsets if train
208
- if (subset_name == "train") and (balance is True):
209
- subset_data, label_dict_subset = balance_gene_split(
210
- subset_data, label_dict_subset, num_proc
211
- )
212
 
213
  # subsample to max_ncells
214
- subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
 
215
 
216
  # relabel genes for this split
217
- def subset_classes_to_ids(example):
218
  example["labels"] = [
219
- label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
220
  ]
221
  return example
222
 
223
- subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
 
 
 
 
224
 
225
- return subset_data
 
226
 
 
227
 
228
- def prep_gene_classifier_all_data(
229
- data, targets, labels, max_ncells, num_proc, balance=False
230
- ):
231
  targets = np.array(targets)
232
  labels = np.array(labels)
233
  label_dict_train = dict(zip(targets, labels))
@@ -245,11 +208,6 @@ def prep_gene_classifier_all_data(
245
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
246
  )
247
 
248
- if balance is True:
249
- train_data, label_dict_train = balance_gene_split(
250
- train_data, label_dict_train, num_proc
251
- )
252
-
253
  # subsample to max_ncells
254
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
255
 
@@ -265,145 +223,6 @@ def prep_gene_classifier_all_data(
265
  return train_data
266
 
267
 
268
- def balance_gene_split(subset_data, label_dict_subset, num_proc):
269
- # count occurrence of genes in each label category
270
- label0_counts, label1_counts = count_genes_for_balancing(
271
- subset_data, label_dict_subset, num_proc
272
- )
273
- label_ratio_0to1 = label0_counts / label1_counts
274
-
275
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
276
- # gene sets already balanced
277
- logger.info(
278
- "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
279
- )
280
- return subset_data, label_dict_subset
281
- else:
282
- label_ratio_0to1_orig = label_ratio_0to1 + 0
283
- label_dict_subset_orig = label_dict_subset.copy()
284
- # balance gene sets
285
- max_ntrials = 25
286
- boost = 1
287
- if label_ratio_0to1 > 10 / 8:
288
- # downsample label 0
289
- for i in range(max_ntrials):
290
- label0 = 0
291
- label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
292
- label0_ngenes = len(label0_genes)
293
- label0_nremove = max(
294
- 1,
295
- int(
296
- np.floor(
297
- label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
298
- )
299
- ),
300
- )
301
- random.seed(i)
302
- label0_remove_genes = random.sample(label0_genes, label0_nremove)
303
- label_dict_subset_new = {
304
- k: v
305
- for k, v in label_dict_subset.items()
306
- if k not in label0_remove_genes
307
- }
308
- label0_counts, label1_counts = count_genes_for_balancing(
309
- subset_data, label_dict_subset_new, num_proc
310
- )
311
- label_ratio_0to1 = label0_counts / label1_counts
312
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
313
- # if gene sets now balanced, return new filtered data and new label_dict_subset
314
- return filter_data_balanced_genes(
315
- subset_data, label_dict_subset_new, num_proc
316
- )
317
- elif label_ratio_0to1 > 10 / 8:
318
- boost = boost * 1.1
319
- elif label_ratio_0to1 < 8 / 10:
320
- boost = boost * 0.9
321
- else:
322
- # downsample label 1
323
- for i in range(max_ntrials):
324
- label1 = 1
325
- label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
326
- label1_ngenes = len(label1_genes)
327
- label1_nremove = max(
328
- 1,
329
- int(
330
- np.floor(
331
- label1_ngenes
332
- - label1_ngenes / ((1 / label_ratio_0to1) * boost)
333
- )
334
- ),
335
- )
336
- random.seed(i)
337
- label1_remove_genes = random.sample(label1_genes, label1_nremove)
338
- label_dict_subset_new = {
339
- k: v
340
- for k, v in label_dict_subset.items()
341
- if k not in label1_remove_genes
342
- }
343
- label0_counts, label1_counts = count_genes_for_balancing(
344
- subset_data, label_dict_subset_new, num_proc
345
- )
346
- label_ratio_0to1 = label0_counts / label1_counts
347
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
348
- # if gene sets now balanced, return new filtered data and new label_dict_subset
349
- return filter_data_balanced_genes(
350
- subset_data, label_dict_subset_new, num_proc
351
- )
352
- elif label_ratio_0to1 < 8 / 10:
353
- boost = boost * 1.1
354
- elif label_ratio_0to1 > 10 / 8:
355
- boost = boost * 0.9
356
-
357
- assert i + 1 == max_ntrials
358
- if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
359
- 10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
360
- ):
361
- label_ratio_0to1 = label_ratio_0to1_orig
362
- label_dict_subset_new = label_dict_subset_orig
363
- logger.warning(
364
- f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
365
- )
366
- return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
367
-
368
-
369
- def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
370
- def count_targets(example):
371
- labels = [
372
- label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
373
- ]
374
- counter_labels = Counter(labels)
375
- # get count of labels 0 or 1, or if absent, return 0
376
- example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
377
- return example
378
-
379
- subset_data = subset_data.map(count_targets, num_proc=num_proc)
380
-
381
- label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
382
- label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
383
-
384
- subset_data = subset_data.remove_columns("labels_counts")
385
-
386
- return label0_counts, label1_counts
387
-
388
-
389
- def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
390
- # function to filter by whether contains labels
391
- def if_contains_subset_label(example):
392
- a = list(label_dict_subset.keys())
393
- b = example["input_ids"]
394
- return not set(a).isdisjoint(b)
395
-
396
- # filter dataset for examples containing classes for this split
397
- logger.info("Filtering data for balanced genes")
398
- subset_data_len_orig = len(subset_data)
399
- subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
400
- logger.info(
401
- f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
402
- )
403
-
404
- return subset_data, label_dict_subset
405
-
406
-
407
  def balance_attr_splits(
408
  data,
409
  attr_to_split,
@@ -490,7 +309,7 @@ def balance_attr_splits(
490
  exp_counts[cat] * sum(obs) / sum(exp_counts.values())
491
  for cat in all_categ
492
  ]
493
- pval = chisquare(f_obs=obs, f_exp=exp).pvalue
494
  train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
495
  eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
496
  df_vals += [train_attr_counts, eval_attr_counts, pval]
@@ -604,45 +423,3 @@ def get_default_train_args(model, classifier, data, output_dir):
604
  training_args.update(default_training_args)
605
 
606
  return training_args, freeze_layers
607
-
608
-
609
- def load_best_model(directory, model_type, num_classes, mode="eval"):
610
- file_dict = dict()
611
- for subdir, dirs, files in os.walk(directory):
612
- for file in files:
613
- if file.endswith("result.json"):
614
- with open(f"{subdir}/{file}", "rb") as fp:
615
- result_json = json.load(fp)
616
- file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
617
- file_df = pd.DataFrame(
618
- {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
619
- )
620
- model_superdir = (
621
- "run-"
622
- + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
623
- .split("_objective_")[2]
624
- .split("_")[0]
625
- )
626
-
627
- for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
628
- for file in files:
629
- if file.endswith("model.safetensors"):
630
- model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
631
- return model
632
-
633
-
634
- class StratifiedKFold3(StratifiedKFold):
635
- def split(self, targets, labels, test_ratio=0.5, groups=None):
636
- s = super().split(targets, labels, groups)
637
- for train_indxs, test_indxs in s:
638
- if test_ratio == 0:
639
- yield train_indxs, test_indxs, None
640
- else:
641
- labels_test = np.array(labels)[test_indxs]
642
- valid_indxs, test_indxs = train_test_split(
643
- test_indxs,
644
- stratify=labels_test,
645
- test_size=test_ratio,
646
- random_state=0,
647
- )
648
- yield train_indxs, valid_indxs, test_indxs
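The removed StratifiedKFold3 above layers a second stratified split on top of each fold's held-out indices to produce train/valid/test index triples. A hedged sketch of the same idea with scikit-learn primitives (function name is illustrative):
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

def three_way_stratified_splits(targets, labels, n_splits=5, test_ratio=0.5, seed=0):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for train_idx, held_out_idx in skf.split(targets, labels):
        # split each fold's held-out portion into validation and test, stratified by label
        valid_idx, test_idx = train_test_split(
            held_out_idx,
            stratify=np.array(labels)[held_out_idx],
            test_size=test_ratio,
            random_state=seed,
        )
        yield train_idx, valid_idx, test_idx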
 
 
1
  import logging
 
2
  import random
3
  from collections import Counter, defaultdict
4
 
 
6
  import pandas as pd
7
  from scipy.stats import chisquare, ranksums
8
  from sklearn.metrics import accuracy_score, f1_score
 
9
 
10
  from . import perturber_utils as pu
11
 
 
133
  ]
134
 
135
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def prep_gene_classifier_split(
137
+ data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
 
 
 
 
 
 
 
 
138
  ):
139
  # generate cross-validation splits
140
  targets = np.array(targets)
141
  labels = np.array(labels)
142
+ targets_train, targets_eval = targets[train_index], targets[eval_index]
143
+ labels_train, labels_eval = labels[train_index], labels[eval_index]
144
+ label_dict_train = dict(zip(targets_train, labels_train))
145
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
146
 
147
  # function to filter by whether contains train or eval labels
148
+ def if_contains_train_label(example):
149
+ a = targets_train
150
+ b = example["input_ids"]
151
+ return not set(a).isdisjoint(b)
152
+
153
+ def if_contains_eval_label(example):
154
+ a = targets_eval
155
  b = example["input_ids"]
156
  return not set(a).isdisjoint(b)
157
 
158
  # filter dataset for examples containing classes for this split
159
+ logger.info(f"Filtering training data for genes in split {iteration_num}")
160
+ train_data = data.filter(if_contains_train_label, num_proc=num_proc)
161
  logger.info(
162
+ f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
163
+ )
164
+ logger.info(f"Filtering evalation data for genes in split {iteration_num}")
165
+ eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
166
+ logger.info(
167
+ f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
168
  )
 
 
 
 
 
 
169
 
170
  # subsample to max_ncells
171
+ train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
172
+ eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
173
 
174
  # relabel genes for this split
175
+ def train_classes_to_ids(example):
176
  example["labels"] = [
177
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
178
  ]
179
  return example
180
 
181
+ def eval_classes_to_ids(example):
182
+ example["labels"] = [
183
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
184
+ ]
185
+ return example
186
 
187
+ train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
188
+ eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
189
 
190
+ return train_data, eval_data
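For gene classification, only the target gene tokens carry a class id; every other position is labeled -100 so the PyTorch loss ignores it. A small illustrative sketch (the token ids and class ids below are made up):
label_dict_train = {17: 0, 42: 1}   # token id -> class id for this split (illustrative)

def train_classes_to_ids(example):
    example["labels"] = [
        label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
    ]
    return example

example = {"input_ids": [5, 17, 99, 42]}
print(train_classes_to_ids(example)["labels"])   # [-100, 0, -100, 1]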
191
 
192
+
193
+ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
 
194
  targets = np.array(targets)
195
  labels = np.array(labels)
196
  label_dict_train = dict(zip(targets, labels))
 
208
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
209
  )
210
 
 
 
 
 
 
211
  # subsample to max_ncells
212
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
213
 
 
223
  return train_data
224
 
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  def balance_attr_splits(
227
  data,
228
  attr_to_split,
 
309
  exp_counts[cat] * sum(obs) / sum(exp_counts.values())
310
  for cat in all_categ
311
  ]
312
+ pval = chisquare(f_obs=obs, f_exp=exp).pvalue
313
  train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
314
  eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
315
  df_vals += [train_attr_counts, eval_attr_counts, pval]
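A hedged sketch of the chi-square check used above to compare an attribute's distribution between the train and eval splits (the category counts are illustrative):
from collections import Counter
from scipy.stats import chisquare

train_counts = Counter({"F": 60, "M": 40})   # observed counts in train (illustrative)
eval_counts = Counter({"F": 30, "M": 20})    # counts in eval, scaled to form the expectation

categories = sorted(set(train_counts) | set(eval_counts))
obs = [train_counts.get(c, 0) for c in categories]
exp = [eval_counts.get(c, 0) * sum(obs) / sum(eval_counts.values()) for c in categories]

pval = chisquare(f_obs=obs, f_exp=exp).pvalue   # a high p-value suggests the splits are balanced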
 
423
  training_args.update(default_training_args)
424
 
425
  return training_args, freeze_layers
 
 
 
 
 
 
 
 
 
 
 
 
 
geneformer/collator_for_classification.py CHANGED
@@ -1,22 +1,24 @@
1
  """
2
  Geneformer collator for gene and cell classification.
 
3
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
4
  """
5
-
 
6
  import warnings
7
  from enum import Enum
8
  from typing import Dict, List, Optional, Union
9
 
10
- import numpy as np
11
- import torch
12
  from transformers import (
13
- BatchEncoding,
14
  DataCollatorForTokenClassification,
15
  SpecialTokensMixin,
 
16
  )
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
 
 
20
  EncodedInput = List[int]
21
  logger = logging.get_logger(__name__)
22
  VERY_LARGE_INTEGER = int(
@@ -28,7 +30,6 @@ LARGE_INTEGER = int(
28
 
29
  # precollator functions
30
 
31
-
32
  class ExplicitEnum(Enum):
33
  """
34
  Enum with more explicit error message for missing values.
@@ -41,7 +42,6 @@ class ExplicitEnum(Enum):
41
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
42
  )
43
 
44
-
45
  class TruncationStrategy(ExplicitEnum):
46
  """
47
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -54,6 +54,7 @@ class TruncationStrategy(ExplicitEnum):
54
  DO_NOT_TRUNCATE = "do_not_truncate"
55
 
56
 
 
57
  class PaddingStrategy(ExplicitEnum):
58
  """
59
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
@@ -65,6 +66,7 @@ class PaddingStrategy(ExplicitEnum):
65
  DO_NOT_PAD = "do_not_pad"
66
 
67
 
 
68
  class TensorType(ExplicitEnum):
69
  """
70
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -76,41 +78,21 @@ class TensorType(ExplicitEnum):
76
  NUMPY = "np"
77
  JAX = "jax"
78
 
79
-
80
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
81
- def __init__(self, *args, **kwargs) -> None:
82
- super().__init__(mask_token="<mask>", pad_token="<pad>")
83
-
84
- self.token_dictionary = kwargs.get("token_dictionary")
85
- self.padding_side = "right"
86
- self.model_input_names = ["input_ids"]
87
- self._mask_token_id = self.token_dictionary.get("<mask>")
88
- self._pad_token_id = self.token_dictionary.get("<pad>")
89
- self._all_special_ids = [
90
- self.token_dictionary.get("<mask>"),
91
- self.token_dictionary.get("<pad>"),
92
- ]
93
-
94
- @property
95
- def all_special_ids(self):
96
- return self._all_special_ids
97
-
98
- @property
99
- def mask_token_id(self):
100
- return self._mask_token_id
101
-
102
- @property
103
- def pad_token_id(self):
104
- return self._pad_token_id
105
 
106
  def _get_padding_truncation_strategies(
107
- self,
108
- padding=True,
109
- truncation=False,
110
- max_length=None,
111
- pad_to_multiple_of=None,
112
- verbose=True,
113
- **kwargs,
114
  ):
115
  """
116
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
@@ -123,9 +105,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
123
  # If you only set max_length, it activates truncation for max_length
124
  if max_length is not None and padding is False and truncation is False:
125
  if verbose:
126
- if not self.deprecation_warnings.get(
127
- "Truncation-not-explicitly-activated", False
128
- ):
129
  logger.warning(
130
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
131
  "please use `truncation=True` to explicitly truncate examples to max length. "
@@ -153,9 +133,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
153
  padding_strategy = PaddingStrategy.MAX_LENGTH
154
  elif padding is not False:
155
  if padding is True:
156
- padding_strategy = (
157
- PaddingStrategy.LONGEST
158
- ) # Default to pad to the longest sequence in the batch
159
  elif not isinstance(padding, PaddingStrategy):
160
  padding_strategy = PaddingStrategy(padding)
161
  elif isinstance(padding, PaddingStrategy):
@@ -195,9 +173,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
195
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
196
  if self.model_max_length > LARGE_INTEGER:
197
  if verbose:
198
- if not self.deprecation_warnings.get(
199
- "Asking-to-pad-to-max_length", False
200
- ):
201
  logger.warning(
202
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
203
  "Default to no padding."
@@ -210,24 +186,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
210
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
211
  if self.model_max_length > LARGE_INTEGER:
212
  if verbose:
213
- if not self.deprecation_warnings.get(
214
- "Asking-to-truncate-to-max_length", False
215
- ):
216
  logger.warning(
217
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
218
  "Default to no truncation."
219
  )
220
- self.deprecation_warnings[
221
- "Asking-to-truncate-to-max_length"
222
- ] = True
223
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
224
  else:
225
  max_length = self.model_max_length
226
 
227
  # Test if we have a padding token
228
- if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
229
- not self.pad_token or self.pad_token_id < 0
230
- ):
231
  raise ValueError(
232
  "Asking to pad but the tokenizer does not have a padding token. "
233
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
@@ -258,7 +228,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
258
  Dict[str, List[EncodedInput]],
259
  List[Dict[str, EncodedInput]],
260
  ],
261
- class_type, # options: "gene" or "cell"
262
  padding: Union[bool, str, PaddingStrategy] = True,
263
  max_length: Optional[int] = None,
264
  pad_to_multiple_of: Optional[int] = None,
@@ -269,23 +239,29 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
269
  """
270
  Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
271
  in the batch.
 
272
  Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
273
  ``self.pad_token_id`` and ``self.pad_token_type_id``)
 
274
  .. note::
 
275
  If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
276
  result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
277
  case of PyTorch tensors, you will lose the specific device of your tensors however.
 
278
  Args:
279
  encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
280
  Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
281
  List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
282
  List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
283
  well as in a PyTorch Dataloader collate function.
 
284
  Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
285
  see the note above for the return type.
286
  padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
287
  Select a strategy to pad the returned sequences (according to the model's padding side and padding
288
  index) among:
 
289
  * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
290
  single sequence if provided).
291
  * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
@@ -296,14 +272,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
296
  Maximum length of the returned list and optionally padding length (see above).
297
  pad_to_multiple_of (:obj:`int`, `optional`):
298
  If set will pad the sequence to a multiple of the provided value.
 
299
  This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
300
  >= 7.5 (Volta).
301
  return_attention_mask (:obj:`bool`, `optional`):
302
  Whether to return the attention mask. If left to the default, will return the attention mask according
303
  to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
 
304
  `What are attention masks? <../glossary.html#attention-mask>`__
305
  return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
306
  If set, will return tensors instead of list of python integers. Acceptable values are:
 
307
  * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
308
  * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
309
  * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
@@ -312,13 +291,8 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
312
  """
313
  # If we have a list of dicts, let's convert it in a dict of lists
314
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
315
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(
316
- encoded_inputs[0], (dict, BatchEncoding)
317
- ):
318
- encoded_inputs = {
319
- key: [example[key] for example in encoded_inputs]
320
- for key in encoded_inputs[0].keys()
321
- }
322
 
323
  # The model's main input name, usually `input_ids`, has be passed for padding
324
  if self.model_input_names[0] not in encoded_inputs:
@@ -412,7 +386,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
412
  def _pad(
413
  self,
414
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
415
- class_type, # options: "gene" or "cell"
416
  max_length: Optional[int] = None,
417
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
418
  pad_to_multiple_of: Optional[int] = None,
@@ -420,15 +394,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
420
  ) -> dict:
421
  """
422
  Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
 
423
  Args:
424
  encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
425
  max_length: maximum length of the returned list and optionally padding length (see below).
426
  Will truncate by taking into account the special tokens.
427
  padding_strategy: PaddingStrategy to use for padding.
 
428
  - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
429
  - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
430
  - PaddingStrategy.DO_NOT_PAD: Do not pad
431
  The tokenizer padding sides are defined in self.padding_side:
 
432
  - 'left': pads on the left of the sequences
433
  - 'right': pads on the right of the sequences
434
  pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
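A minimal sketch of what right-padding produces for one example under a fixed max length (token ids are illustrative; for gene classification the labels are padded with -100 so padded positions are ignored by the loss):
pad_token_id = 0    # illustrative; taken from the token dictionary in practice
max_length = 6

input_ids = [5, 17, 42, 99]
difference = max_length - len(input_ids)

attention_mask = [1] * len(input_ids) + [0] * difference   # [1, 1, 1, 1, 0, 0]
input_ids = input_ids + [pad_token_id] * difference        # [5, 17, 42, 99, 0, 0]
labels = [0, 1, -100, 1] + [-100] * difference             # gene labels padded with -100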
@@ -445,73 +422,46 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
445
  if padding_strategy == PaddingStrategy.LONGEST:
446
  max_length = len(required_input)
447
 
448
- if (
449
- max_length is not None
450
- and pad_to_multiple_of is not None
451
- and (max_length % pad_to_multiple_of != 0)
452
- ):
453
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
454
 
455
- needs_to_be_padded = (
456
- padding_strategy != PaddingStrategy.DO_NOT_PAD
457
- and len(required_input) != max_length
458
- )
459
 
460
  if needs_to_be_padded:
461
  difference = max_length - len(required_input)
462
  if self.padding_side == "right":
463
  if return_attention_mask:
464
- encoded_inputs["attention_mask"] = [1] * len(required_input) + [
465
- 0
466
- ] * difference
467
  if "token_type_ids" in encoded_inputs:
468
  encoded_inputs["token_type_ids"] = (
469
- encoded_inputs["token_type_ids"]
470
- + [self.pad_token_type_id] * difference
471
  )
472
  if "special_tokens_mask" in encoded_inputs:
473
- encoded_inputs["special_tokens_mask"] = (
474
- encoded_inputs["special_tokens_mask"] + [1] * difference
475
- )
476
- encoded_inputs[self.model_input_names[0]] = (
477
- required_input + [self.pad_token_id] * difference
478
- )
479
  if class_type == "gene":
480
- encoded_inputs["labels"] = (
481
- encoded_inputs["labels"] + [-100] * difference
482
- )
483
  elif self.padding_side == "left":
484
  if return_attention_mask:
485
- encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
486
- required_input
487
- )
488
  if "token_type_ids" in encoded_inputs:
489
- encoded_inputs["token_type_ids"] = [
490
- self.pad_token_type_id
491
- ] * difference + encoded_inputs["token_type_ids"]
492
  if "special_tokens_mask" in encoded_inputs:
493
- encoded_inputs["special_tokens_mask"] = [
494
- 1
495
- ] * difference + encoded_inputs["special_tokens_mask"]
496
- encoded_inputs[self.model_input_names[0]] = [
497
- self.pad_token_id
498
- ] * difference + required_input
499
  if class_type == "gene":
500
- encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
501
- "labels"
502
- ]
503
  else:
504
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
505
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
506
  encoded_inputs["attention_mask"] = [1] * len(required_input)
507
-
508
  return encoded_inputs
509
 
510
  def get_special_tokens_mask(
511
- self,
512
- token_ids_0: List[int],
513
- token_ids_1: Optional[List[int]] = None,
514
- already_has_special_tokens: bool = False,
515
  ) -> List[int]:
516
  """
517
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -535,15 +485,11 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
535
 
536
  all_special_ids = self.all_special_ids # cache the property
537
 
538
- special_tokens_mask = [
539
- 1 if token in all_special_ids else 0 for token in token_ids_0
540
- ]
541
 
542
  return special_tokens_mask
543
 
544
- def convert_tokens_to_ids(
545
- self, tokens: Union[str, List[str]]
546
- ) -> Union[int, List[int]]:
547
  """
548
  Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
549
  vocabulary.
@@ -567,15 +513,14 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
567
  if token is None:
568
  return None
569
 
570
- return self.token_dictionary.get(token)
571
 
572
  def __len__(self):
573
- return len(self.token_dictionary)
574
 
575
 
576
  # collator functions
577
 
578
-
579
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
580
  """
581
  Data collator that will dynamically pad the inputs received, as well as the labels.
@@ -601,33 +546,25 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
601
  The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
602
  """
603
 
 
604
  class_type = "gene"
605
  padding: Union[bool, str, PaddingStrategy] = True
606
  max_length: Optional[int] = None
607
  pad_to_multiple_of: Optional[int] = None
608
  label_pad_token_id: int = -100
609
-
610
  def __init__(self, *args, **kwargs) -> None:
611
- self.token_dictionary = kwargs.pop("token_dictionary")
612
  super().__init__(
613
- tokenizer=PrecollatorForGeneAndCellClassification(
614
- token_dictionary=self.token_dictionary
615
- ),
616
  padding=self.padding,
617
  max_length=self.max_length,
618
  pad_to_multiple_of=self.pad_to_multiple_of,
619
  label_pad_token_id=self.label_pad_token_id,
620
- *args,
621
- **kwargs,
622
- )
623
 
624
  def _prepare_batch(self, features):
625
  label_name = "label" if "label" in features[0].keys() else "labels"
626
- labels = (
627
- [feature[label_name] for feature in features]
628
- if label_name in features[0].keys()
629
- else None
630
- )
631
  batch = self.tokenizer.pad(
632
  features,
633
  class_type=self.class_type,
@@ -637,31 +574,29 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
637
  return_tensors="pt",
638
  )
639
  return batch
640
-
641
  def __call__(self, features):
642
  batch = self._prepare_batch(features)
643
 
644
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
645
  return batch
646
 
647
-
648
  class DataCollatorForCellClassification(DataCollatorForGeneClassification):
 
649
  class_type = "cell"
650
 
651
  def _prepare_batch(self, features):
 
652
  batch = super()._prepare_batch(features)
653
-
654
  # Special handling for labels.
655
  # Ensure that tensor is created with the correct type
656
  # (it should be automatically the case, but let's make sure of it.)
657
  first = features[0]
658
  if "label" in first and first["label"] is not None:
659
- label = (
660
- first["label"].item()
661
- if isinstance(first["label"], torch.Tensor)
662
- else first["label"]
663
- )
664
  dtype = torch.long if isinstance(label, int) else torch.float
665
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
666
-
667
  return batch
 
1
  """
2
  Geneformer collator for gene and cell classification.
3
+
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
+ import numpy as np
7
+ import torch
8
  import warnings
9
  from enum import Enum
10
  from typing import Dict, List, Optional, Union
11
 
 
 
12
  from transformers import (
 
13
  DataCollatorForTokenClassification,
14
  SpecialTokensMixin,
15
+ BatchEncoding,
16
  )
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
20
+ from .pretrainer import token_dictionary
21
+
22
  EncodedInput = List[int]
23
  logger = logging.get_logger(__name__)
24
  VERY_LARGE_INTEGER = int(
 
30
 
31
  # precollator functions
32
 
 
33
  class ExplicitEnum(Enum):
34
  """
35
  Enum with more explicit error message for missing values.
 
42
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
43
  )
44
 
 
45
  class TruncationStrategy(ExplicitEnum):
46
  """
47
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
54
  DO_NOT_TRUNCATE = "do_not_truncate"
55
 
56
 
57
+
58
  class PaddingStrategy(ExplicitEnum):
59
  """
60
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
 
66
  DO_NOT_PAD = "do_not_pad"
67
 
68
 
69
+
70
  class TensorType(ExplicitEnum):
71
  """
72
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
78
  NUMPY = "np"
79
  JAX = "jax"
80
 
81
+
82
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
83
+ mask_token = "<mask>"
84
+ mask_token_id = token_dictionary.get("<mask>")
85
+ pad_token = "<pad>"
86
+ pad_token_id = token_dictionary.get("<pad>")
87
+ padding_side = "right"
88
+ all_special_ids = [
89
+ token_dictionary.get("<mask>"),
90
+ token_dictionary.get("<pad>")
91
+ ]
92
+ model_input_names = ["input_ids"]
 
 
 
 
 
 
93
 
94
  def _get_padding_truncation_strategies(
95
+ self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
 
 
 
 
 
 
96
  ):
97
  """
98
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
 
105
  # If you only set max_length, it activates truncation for max_length
106
  if max_length is not None and padding is False and truncation is False:
107
  if verbose:
108
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
 
 
109
  logger.warning(
110
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
111
  "please use `truncation=True` to explicitly truncate examples to max length. "
 
133
  padding_strategy = PaddingStrategy.MAX_LENGTH
134
  elif padding is not False:
135
  if padding is True:
136
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
 
 
137
  elif not isinstance(padding, PaddingStrategy):
138
  padding_strategy = PaddingStrategy(padding)
139
  elif isinstance(padding, PaddingStrategy):
 
173
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
174
  if self.model_max_length > LARGE_INTEGER:
175
  if verbose:
176
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
 
 
177
  logger.warning(
178
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
179
  "Default to no padding."
 
186
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
187
  if self.model_max_length > LARGE_INTEGER:
188
  if verbose:
189
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
 
 
190
  logger.warning(
191
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
192
  "Default to no truncation."
193
  )
194
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
 
 
195
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
196
  else:
197
  max_length = self.model_max_length
198
 
199
  # Test if we have a padding token
200
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
 
 
201
  raise ValueError(
202
  "Asking to pad but the tokenizer does not have a padding token. "
203
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
 
228
  Dict[str, List[EncodedInput]],
229
  List[Dict[str, EncodedInput]],
230
  ],
231
+ class_type, # options: "gene" or "cell"
232
  padding: Union[bool, str, PaddingStrategy] = True,
233
  max_length: Optional[int] = None,
234
  pad_to_multiple_of: Optional[int] = None,
 
239
  """
240
  Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
241
  in the batch.
242
+
243
  Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
244
  ``self.pad_token_id`` and ``self.pad_token_type_id``)
245
+
246
  .. note::
247
+
248
  If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
249
  result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
250
  case of PyTorch tensors, you will lose the specific device of your tensors however.
251
+
252
  Args:
253
  encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
254
  Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
255
  List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
256
  List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
257
  well as in a PyTorch Dataloader collate function.
258
+
259
  Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
260
  see the note above for the return type.
261
  padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
262
  Select a strategy to pad the returned sequences (according to the model's padding side and padding
263
  index) among:
264
+
265
  * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
266
  single sequence is provided).
267
  * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
 
272
  Maximum length of the returned list and optionally padding length (see above).
273
  pad_to_multiple_of (:obj:`int`, `optional`):
274
  If set will pad the sequence to a multiple of the provided value.
275
+
276
  This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
277
  >= 7.5 (Volta).
278
  return_attention_mask (:obj:`bool`, `optional`):
279
  Whether to return the attention mask. If left to the default, will return the attention mask according
280
  to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
281
+
282
  `What are attention masks? <../glossary.html#attention-mask>`__
283
  return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
284
  If set, will return tensors instead of list of python integers. Acceptable values are:
285
+
286
  * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
287
  * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
288
  * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
 
291
  """
292
  # If we have a list of dicts, let's convert it in a dict of lists
293
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
294
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
295
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
296
 
297
  # The model's main input name, usually `input_ids`, has to be passed for padding
298
  if self.model_input_names[0] not in encoded_inputs:
 
386
  def _pad(
387
  self,
388
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
389
+ class_type, # options: "gene" or "cell"
390
  max_length: Optional[int] = None,
391
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
392
  pad_to_multiple_of: Optional[int] = None,
 
394
  ) -> dict:
395
  """
396
  Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
397
+
398
  Args:
399
  encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
400
  max_length: maximum length of the returned list and optionally padding length (see below).
401
  Will truncate by taking into account the special tokens.
402
  padding_strategy: PaddingStrategy to use for padding.
403
+
404
  - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
405
  - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
406
  - PaddingStrategy.DO_NOT_PAD: Do not pad
407
  The tokenizer padding sides are defined in self.padding_side:
408
+
409
  - 'left': pads on the left of the sequences
410
  - 'right': pads on the right of the sequences
411
  pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
 
422
  if padding_strategy == PaddingStrategy.LONGEST:
423
  max_length = len(required_input)
424
 
425
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
 
 
 
 
426
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
427
 
428
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
 
 
429
 
430
  if needs_to_be_padded:
431
  difference = max_length - len(required_input)
432
  if self.padding_side == "right":
433
  if return_attention_mask:
434
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
 
 
435
  if "token_type_ids" in encoded_inputs:
436
  encoded_inputs["token_type_ids"] = (
437
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
 
438
  )
439
  if "special_tokens_mask" in encoded_inputs:
440
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
441
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
 
 
 
 
442
  if class_type == "gene":
443
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
 
 
444
  elif self.padding_side == "left":
445
  if return_attention_mask:
446
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
 
 
447
  if "token_type_ids" in encoded_inputs:
448
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
449
+ "token_type_ids"
450
+ ]
451
  if "special_tokens_mask" in encoded_inputs:
452
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
453
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
 
 
 
454
  if class_type == "gene":
455
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
 
 
456
  else:
457
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
458
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
459
  encoded_inputs["attention_mask"] = [1] * len(required_input)
460
+
461
  return encoded_inputs
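For reference (not part of the diff): a toy illustration of what the right-padding branch above produces for a gene-classification example. The pad_token_id of 0 and the token ids below are made-up values; real ids come from the token dictionary.

# Illustration only; mirrors the class_type == "gene", padding_side == "right" branch.
required_input = [5, 17, 42]        # input_ids for one cell
labels = [1, 0, 1]                  # per-gene labels
max_length, pad_token_id = 6, 0     # assumed values
difference = max_length - len(required_input)

padded = {
    "input_ids": required_input + [pad_token_id] * difference,        # [5, 17, 42, 0, 0, 0]
    "attention_mask": [1] * len(required_input) + [0] * difference,   # [1, 1, 1, 0, 0, 0]
    "labels": labels + [-100] * difference,                           # [1, 0, 1, -100, -100, -100]
}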
462
 
463
  def get_special_tokens_mask(
464
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
 
 
 
465
  ) -> List[int]:
466
  """
467
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
 
485
 
486
  all_special_ids = self.all_special_ids # cache the property
487
 
488
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
 
 
489
 
490
  return special_tokens_mask
491
 
492
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
 
 
493
  """
494
  Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
495
  vocabulary.
 
513
  if token is None:
514
  return None
515
 
516
+ return token_dictionary.get(token)
517
 
518
  def __len__(self):
519
+ return len(token_dictionary)
520
 
521
 
522
  # collator functions
523
 
 
524
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
525
  """
526
  Data collator that will dynamically pad the inputs received, as well as the labels.
 
546
  The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
547
  """
548
 
549
+ tokenizer = PrecollatorForGeneAndCellClassification()
550
  class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
553
  pad_to_multiple_of: Optional[int] = None
554
  label_pad_token_id: int = -100
555
+
556
  def __init__(self, *args, **kwargs) -> None:
 
557
  super().__init__(
558
+ tokenizer=self.tokenizer,
 
 
559
  padding=self.padding,
560
  max_length=self.max_length,
561
  pad_to_multiple_of=self.pad_to_multiple_of,
562
  label_pad_token_id=self.label_pad_token_id,
563
+ *args, **kwargs)
 
 
564
 
565
  def _prepare_batch(self, features):
566
  label_name = "label" if "label" in features[0].keys() else "labels"
567
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
  batch = self.tokenizer.pad(
569
  features,
570
  class_type=self.class_type,
 
574
  return_tensors="pt",
575
  )
576
  return batch
577
+
578
  def __call__(self, features):
579
  batch = self._prepare_batch(features)
580
 
581
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
582
  return batch
583
 
584
+
585
  class DataCollatorForCellClassification(DataCollatorForGeneClassification):
586
+
587
  class_type = "cell"
588
 
589
  def _prepare_batch(self, features):
590
+
591
  batch = super()._prepare_batch(features)
592
+
593
  # Special handling for labels.
594
  # Ensure that tensor is created with the correct type
595
  # (it should be automatically the case, but let's make sure of it.)
596
  first = features[0]
597
  if "label" in first and first["label"] is not None:
598
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
 
 
 
 
599
  dtype = torch.long if isinstance(label, int) else torch.float
600
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
601
+
602
  return batch
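For context (not part of the diff): a minimal sketch of how DataCollatorForCellClassification is typically dropped into a Hugging Face Trainer. The dataset path, checkpoint path, and label count are placeholders, not taken from this repository.

from datasets import load_from_disk
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

from geneformer import DataCollatorForCellClassification

# tokenized .dataset with a "label" column (placeholder path)
train_ds = load_from_disk("tokenized_cells.dataset")

model = BertForSequenceClassification.from_pretrained(
    "/path/to/geneformer_checkpoint", num_labels=3
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="cell_clf_out", per_device_train_batch_size=12),
    data_collator=DataCollatorForCellClassification(),  # pads each batch and builds the "labels" tensor
    train_dataset=train_ds,
)
trainer.train()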
geneformer/emb_extractor.py CHANGED
@@ -24,8 +24,8 @@ import torch
24
  from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
27
- from . import TOKEN_DICTIONARY_FILE
28
  from . import perturber_utils as pu
 
29
 
30
  logger = logging.getLogger(__name__)
31
 
@@ -38,8 +38,6 @@ def get_embs(
38
  layer_to_quant,
39
  pad_token_id,
40
  forward_batch_size,
41
- token_gene_dict,
42
- special_token=False,
43
  summary_stat=None,
44
  silent=False,
45
  ):
@@ -49,8 +47,10 @@ def get_embs(
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
- # get # of emb dims
53
- emb_dims = pu.get_model_emb_dims(model)
 
 
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
@@ -67,27 +67,6 @@ def get_embs(
67
  k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
  }
69
 
70
- # Check if CLS and EOS token is present in the token dictionary
71
- cls_present = any("<cls>" in value for value in token_gene_dict.values())
72
- eos_present = any("<eos>" in value for value in token_gene_dict.values())
73
- if emb_mode == "cls":
74
- assert cls_present, "<cls> token missing in token dictionary"
75
- # Check to make sure that the first token of the filtered input data is cls token
76
- gene_token_dict = {v: k for k, v in token_gene_dict.items()}
77
- cls_token_id = gene_token_dict["<cls>"]
78
- assert (
79
- filtered_input_data["input_ids"][0][0] == cls_token_id
80
- ), "First token is not <cls> token value"
81
- elif emb_mode == "cell":
82
- if cls_present:
83
- logger.warning(
84
- "CLS token present in token dictionary, excluding from average."
85
- )
86
- if eos_present:
87
- logger.warning(
88
- "EOS token present in token dictionary, excluding from average."
89
- )
90
-
91
  overall_max_len = 0
92
 
93
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
@@ -113,14 +92,7 @@ def get_embs(
113
  embs_i = outputs.hidden_states[layer_to_quant]
114
 
115
  if emb_mode == "cell":
116
- if cls_present:
117
- non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
118
- if eos_present:
119
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
120
- else:
121
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
122
- else:
123
- mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
124
  if summary_stat is None:
125
  embs_list.append(mean_embs)
126
  elif summary_stat is not None:
@@ -149,12 +121,6 @@ def get_embs(
149
  accumulate_tdigests(
150
  embs_tdigests_dict[int(k)], dict_h[k], emb_dims
151
  )
152
- del embs_h
153
- del dict_h
154
- elif emb_mode == "cls":
155
- cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
156
- embs_list.append(cls_embs)
157
- del cls_embs
158
 
159
  overall_max_len = max(overall_max_len, max_len)
160
  del outputs
@@ -165,7 +131,7 @@ def get_embs(
165
  torch.cuda.empty_cache()
166
 
167
  if summary_stat is None:
168
- if (emb_mode == "cell") or (emb_mode == "cls"):
169
  embs_stack = torch.cat(embs_list, dim=0)
170
  elif emb_mode == "gene":
171
  embs_stack = pu.pad_tensor_list(
@@ -243,6 +209,14 @@ def tdigest_median(embs_tdigests, emb_dims):
243
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
244
 
245
246
  def label_cell_embs(embs, downsampled_data, emb_labels):
247
  embs_df = pd.DataFrame(embs.cpu().numpy())
248
  if emb_labels is not None:
@@ -278,7 +252,7 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
278
  return embs_df
279
 
280
 
281
- def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
283
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
@@ -288,27 +262,15 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
288
  obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
289
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
290
  sc.tl.pca(adata, svd_solver="arpack")
291
- sc.pp.neighbors(adata, random_state=seed)
292
- sc.tl.umap(adata, random_state=seed)
293
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
294
  sns.set_style("white")
295
- default_kwargs_dict = {"size": 200}
296
  if kwargs_dict is not None:
297
  default_kwargs_dict.update(kwargs_dict)
298
 
299
- cats = set(embs_df[label])
300
-
301
- with plt.rc_context():
302
- ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
- ax.legend(
304
- markerscale=2,
305
- frameon=False,
306
- loc="center left",
307
- bbox_to_anchor=(1, 0.5),
308
- ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
- )
310
- plt.show()
311
- plt.savefig(output_file, bbox_inches="tight")
312
 
313
 
314
  def gen_heatmap_class_colors(labels, df):
@@ -384,8 +346,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
384
  bbox_to_anchor=(0.5, 1),
385
  facecolor="white",
386
  )
387
- plt.show()
388
- logger.info(f"Output file: {output_file}")
389
  plt.savefig(output_file, bbox_inches="tight")
390
 
391
 
@@ -393,7 +354,7 @@ class EmbExtractor:
393
  valid_option_dict = {
394
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
395
  "num_classes": {int},
396
- "emb_mode": {"cls", "cell", "gene"},
397
  "cell_emb_style": {"mean_pool"},
398
  "gene_emb_style": {"mean_pool"},
399
  "filter_data": {None, dict},
@@ -402,7 +363,6 @@ class EmbExtractor:
402
  "emb_label": {None, list},
403
  "labels_to_plot": {None, list},
404
  "forward_batch_size": {int},
405
- "token_dictionary_file": {None, str},
406
  "nproc": {int},
407
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
408
  }
@@ -411,7 +371,7 @@ class EmbExtractor:
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
- emb_mode="cls",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
@@ -422,7 +382,7 @@ class EmbExtractor:
422
  forward_batch_size=100,
423
  nproc=4,
424
  summary_stat=None,
425
- token_dictionary_file=None,
426
  ):
427
  """
428
  Initialize embedding extractor.
@@ -434,11 +394,10 @@ class EmbExtractor:
434
  num_classes : int
435
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
437
- emb_mode : {"cls", "cell", "gene"}
438
- | Whether to output CLS, cell, or gene embeddings.
439
- | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
440
- cell_emb_style : {"mean_pool"}
441
- | Method for summarizing cell embeddings if not using CLS token.
442
  | Currently only option is mean pooling of gene embeddings for given cell.
443
  gene_emb_style : "mean_pool"
444
  | Method for summarizing gene embeddings.
@@ -473,7 +432,6 @@ class EmbExtractor:
473
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
474
  | Non-exact is slower but more memory-efficient.
475
  token_dictionary_file : Path
476
- | Default is the Geneformer token dictionary
477
  | Path to pickle file containing token dictionary (Ensembl ID:token).
478
 
479
  **Examples:**
@@ -486,6 +444,7 @@ class EmbExtractor:
486
  ... emb_mode="cell",
487
  ... filter_data={"cell_type":["cardiomyocyte"]},
488
  ... max_ncells=1000,
 
489
  ... emb_layer=-1,
490
  ... emb_label=["disease", "cell_type"],
491
  ... labels_to_plot=["disease", "cell_type"])
@@ -502,7 +461,6 @@ class EmbExtractor:
502
  self.emb_layer = emb_layer
503
  self.emb_label = emb_label
504
  self.labels_to_plot = labels_to_plot
505
- self.token_dictionary_file = token_dictionary_file
506
  self.forward_batch_size = forward_batch_size
507
  self.nproc = nproc
508
  if (summary_stat is not None) and ("exact" in summary_stat):
@@ -515,8 +473,6 @@ class EmbExtractor:
515
  self.validate_options()
516
 
517
  # load token dictionary (Ensembl IDs:token)
518
- if self.token_dictionary_file is None:
519
- token_dictionary_file = TOKEN_DICTIONARY_FILE
520
  with open(token_dictionary_file, "rb") as f:
521
  self.gene_token_dict = pickle.load(f)
522
 
@@ -532,7 +488,7 @@ class EmbExtractor:
532
  continue
533
  valid_type = False
534
  for option in valid_options:
535
- if (option in [int, list, dict, bool, str]) and isinstance(
536
  attr_value, option
537
  ):
538
  valid_type = True
@@ -606,14 +562,13 @@ class EmbExtractor:
606
  )
607
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
608
  embs = get_embs(
609
- model=model,
610
- filtered_input_data=downsampled_data,
611
- emb_mode=self.emb_mode,
612
- layer_to_quant=layer_to_quant,
613
- pad_token_id=self.pad_token_id,
614
- forward_batch_size=self.forward_batch_size,
615
- token_gene_dict=self.token_gene_dict,
616
- summary_stat=self.summary_stat,
617
  )
618
 
619
  if self.emb_mode == "cell":
@@ -627,8 +582,6 @@ class EmbExtractor:
627
  elif self.summary_stat is not None:
628
  embs_df = pd.DataFrame(embs).T
629
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
630
- elif self.emb_mode == "cls":
631
- embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
632
 
633
  # save embeddings to output_path
634
  if cell_state is None:
@@ -637,17 +590,13 @@ class EmbExtractor:
637
 
638
  if self.exact_summary_stat == "exact_mean":
639
  embs = embs.mean(dim=0)
640
- emb_dims = pu.get_model_emb_dims(model)
641
  embs_df = pd.DataFrame(
642
- embs_df[0 : emb_dims - 1].mean(axis="rows"),
643
- columns=[self.exact_summary_stat],
644
  ).T
645
  elif self.exact_summary_stat == "exact_median":
646
  embs = torch.median(embs, dim=0)[0]
647
- emb_dims = pu.get_model_emb_dims(model)
648
  embs_df = pd.DataFrame(
649
- embs_df[0 : emb_dims - 1].median(axis="rows"),
650
- columns=[self.exact_summary_stat],
651
  ).T
652
 
653
  if cell_state is not None:
@@ -800,15 +749,15 @@ class EmbExtractor:
800
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
801
  raise
802
 
803
- if max_ncells_to_plot is not None:
804
- if max_ncells_to_plot > self.max_ncells:
805
- max_ncells_to_plot = self.max_ncells
806
- logger.warning(
807
- "max_ncells_to_plot must be <= max_ncells. "
808
- f"Changing max_ncells_to_plot to {self.max_ncells}."
809
- )
810
- elif max_ncells_to_plot < self.max_ncells:
811
- embs = embs.sample(max_ncells_to_plot, axis=0)
812
 
813
  if self.emb_label is None:
814
  label_len = 0
@@ -830,11 +779,11 @@ class EmbExtractor:
830
  f"not present in provided embeddings dataframe."
831
  )
832
  continue
833
- output_prefix_label = output_prefix + f"_umap_{label}"
834
  output_file = (
835
  Path(output_directory) / output_prefix_label
836
  ).with_suffix(".pdf")
837
- plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
838
 
839
  if plot_style == "heatmap":
840
  for label in self.labels_to_plot:
 
24
  from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
 
27
  from . import perturber_utils as pu
28
+ from .tokenizer import TOKEN_DICTIONARY_FILE
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
38
  layer_to_quant,
39
  pad_token_id,
40
  forward_batch_size,
 
 
41
  summary_stat=None,
42
  silent=False,
43
  ):
 
47
  if summary_stat is None:
48
  embs_list = []
49
  elif summary_stat is not None:
50
+ # test embedding extraction for example cell and extract # emb dims
51
+ example = filtered_input_data.select([i for i in range(1)])
52
+ example.set_format(type="torch")
53
+ emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
67
  k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
  }
69
70
  overall_max_len = 0
71
 
72
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
 
92
  embs_i = outputs.hidden_states[layer_to_quant]
93
 
94
  if emb_mode == "cell":
95
+ mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
96
  if summary_stat is None:
97
  embs_list.append(mean_embs)
98
  elif summary_stat is not None:
 
121
  accumulate_tdigests(
122
  embs_tdigests_dict[int(k)], dict_h[k], emb_dims
123
  )
124
 
125
  overall_max_len = max(overall_max_len, max_len)
126
  del outputs
 
131
  torch.cuda.empty_cache()
132
 
133
  if summary_stat is None:
134
+ if emb_mode == "cell":
135
  embs_stack = torch.cat(embs_list, dim=0)
136
  elif emb_mode == "gene":
137
  embs_stack = pu.pad_tensor_list(
 
209
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
210
 
211
 
212
+ def test_emb(model, example, layer_to_quant):
213
+ with torch.no_grad():
214
+ outputs = model(input_ids=example.to("cuda"))
215
+
216
+ embs_test = outputs.hidden_states[layer_to_quant]
217
+ return embs_test.size()[2]
218
+
219
+
220
  def label_cell_embs(embs, downsampled_data, emb_labels):
221
  embs_df = pd.DataFrame(embs.cpu().numpy())
222
  if emb_labels is not None:
 
252
  return embs_df
253
 
254
 
255
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
256
  only_embs_df = embs_df.iloc[:, :emb_dims]
257
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
258
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
 
262
  obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
263
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
264
  sc.tl.pca(adata, svd_solver="arpack")
265
+ sc.pp.neighbors(adata)
266
+ sc.tl.umap(adata)
267
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
268
  sns.set_style("white")
269
+ default_kwargs_dict = {"palette": "Set2", "size": 200}
270
  if kwargs_dict is not None:
271
  default_kwargs_dict.update(kwargs_dict)
272
 
273
+ sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
274
 
275
 
276
  def gen_heatmap_class_colors(labels, df):
 
346
  bbox_to_anchor=(0.5, 1),
347
  facecolor="white",
348
  )
349
+
 
350
  plt.savefig(output_file, bbox_inches="tight")
351
 
352
 
 
354
  valid_option_dict = {
355
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
356
  "num_classes": {int},
357
+ "emb_mode": {"cell", "gene"},
358
  "cell_emb_style": {"mean_pool"},
359
  "gene_emb_style": {"mean_pool"},
360
  "filter_data": {None, dict},
 
363
  "emb_label": {None, list},
364
  "labels_to_plot": {None, list},
365
  "forward_batch_size": {int},
 
366
  "nproc": {int},
367
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
368
  }
 
371
  self,
372
  model_type="Pretrained",
373
  num_classes=0,
374
+ emb_mode="cell",
375
  cell_emb_style="mean_pool",
376
  gene_emb_style="mean_pool",
377
  filter_data=None,
 
382
  forward_batch_size=100,
383
  nproc=4,
384
  summary_stat=None,
385
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
386
  ):
387
  """
388
  Initialize embedding extractor.
 
394
  num_classes : int
395
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
396
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
397
+ emb_mode : {"cell", "gene"}
398
+ | Whether to output cell or gene embeddings.
399
+ cell_emb_style : "mean_pool"
400
+ | Method for summarizing cell embeddings.
 
401
  | Currently only option is mean pooling of gene embeddings for given cell.
402
  gene_emb_style : "mean_pool"
403
  | Method for summarizing gene embeddings.
 
432
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
433
  | Non-exact is slower but more memory-efficient.
434
  token_dictionary_file : Path
 
435
  | Path to pickle file containing token dictionary (Ensembl ID:token).
436
 
437
  **Examples:**
 
444
  ... emb_mode="cell",
445
  ... filter_data={"cell_type":["cardiomyocyte"]},
446
  ... max_ncells=1000,
447
+ ... max_ncells_to_plot=1000,
448
  ... emb_layer=-1,
449
  ... emb_label=["disease", "cell_type"],
450
  ... labels_to_plot=["disease", "cell_type"])
 
461
  self.emb_layer = emb_layer
462
  self.emb_label = emb_label
463
  self.labels_to_plot = labels_to_plot
 
464
  self.forward_batch_size = forward_batch_size
465
  self.nproc = nproc
466
  if (summary_stat is not None) and ("exact" in summary_stat):
 
473
  self.validate_options()
474
 
475
  # load token dictionary (Ensembl IDs:token)
 
 
476
  with open(token_dictionary_file, "rb") as f:
477
  self.gene_token_dict = pickle.load(f)
478
 
 
488
  continue
489
  valid_type = False
490
  for option in valid_options:
491
+ if (option in [int, list, dict, bool]) and isinstance(
492
  attr_value, option
493
  ):
494
  valid_type = True
 
562
  )
563
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
564
  embs = get_embs(
565
+ model,
566
+ downsampled_data,
567
+ self.emb_mode,
568
+ layer_to_quant,
569
+ self.pad_token_id,
570
+ self.forward_batch_size,
571
+ self.summary_stat,
 
572
  )
573
 
574
  if self.emb_mode == "cell":
 
582
  elif self.summary_stat is not None:
583
  embs_df = pd.DataFrame(embs).T
584
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
 
 
585
 
586
  # save embeddings to output_path
587
  if cell_state is None:
 
590
 
591
  if self.exact_summary_stat == "exact_mean":
592
  embs = embs.mean(dim=0)
 
593
  embs_df = pd.DataFrame(
594
+ embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
 
595
  ).T
596
  elif self.exact_summary_stat == "exact_median":
597
  embs = torch.median(embs, dim=0)[0]
 
598
  embs_df = pd.DataFrame(
599
+ embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
 
600
  ).T
601
 
602
  if cell_state is not None:
 
749
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
750
  raise
751
 
752
+ if max_ncells_to_plot > self.max_ncells:
753
+ max_ncells_to_plot = self.max_ncells
754
+ logger.warning(
755
+ "max_ncells_to_plot must be <= max_ncells. "
756
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
757
+ )
758
+
759
+ if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
760
+ embs = embs.sample(max_ncells_to_plot, axis=0)
761
 
762
  if self.emb_label is None:
763
  label_len = 0
 
779
  f"not present in provided embeddings dataframe."
780
  )
781
  continue
782
+ output_prefix_label = "_" + output_prefix + f"_umap_{label}"
783
  output_file = (
784
  Path(output_directory) / output_prefix_label
785
  ).with_suffix(".pdf")
786
+ plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
787
 
788
  if plot_style == "heatmap":
789
  for label in self.labels_to_plot:
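For context (not part of the diff): a minimal sketch of the EmbExtractor workflow this file implements, mirroring the docstring example shown above. All paths are placeholders; this version of the extractor supports emb_mode "cell" or "gene".

from geneformer import EmbExtractor

embex = EmbExtractor(
    model_type="CellClassifier",
    num_classes=3,
    emb_mode="cell",
    filter_data={"cell_type": ["cardiomyocyte"]},
    max_ncells=1000,
    emb_layer=-1,
    emb_label=["disease", "cell_type"],
    labels_to_plot=["disease"],
    forward_batch_size=100,
    nproc=8,
)

# returns a pandas DataFrame of embeddings plus any emb_label columns
embs = embex.extract_embs(
    "path/to/model",              # model_directory
    "path/to/tokenized.dataset",  # input_data_file
    "output/dir",                 # output_directory
    "cm_embs",                    # output_prefix
)

# writes one UMAP PDF per label via scanpy's sc.pl.umap
embex.plot_embs(
    embs=embs,
    plot_style="umap",
    output_directory="output/dir",
    output_prefix="cm_embs",
    max_ncells_to_plot=1000,
)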
geneformer/ensembl_mapping_dict_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
3
- size 3957652
 
 
 
 
geneformer/evaluation_utils.py CHANGED
@@ -20,20 +20,20 @@ from sklearn.metrics import (
20
  )
21
  from tqdm.auto import trange
22
 
23
- from . import TOKEN_DICTIONARY_FILE
24
  from .emb_extractor import make_colorbar
 
25
 
26
  logger = logging.getLogger(__name__)
27
28
 
29
  def preprocess_classifier_batch(cell_batch, max_len, label_name):
30
  if max_len is None:
31
  max_len = max([len(i) for i in cell_batch["input_ids"]])
32
 
33
- # load token dictionary (Ensembl IDs:token)
34
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
35
- gene_token_dict = pickle.load(f)
36
-
37
  def pad_label_example(example):
38
  example[label_name] = np.pad(
39
  example[label_name],
 
20
  )
21
  from tqdm.auto import trange
22
 
 
23
  from .emb_extractor import make_colorbar
24
+ from .tokenizer import TOKEN_DICTIONARY_FILE
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
+ # load token dictionary (Ensembl IDs:token)
29
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
30
+ gene_token_dict = pickle.load(f)
31
+
32
 
33
  def preprocess_classifier_batch(cell_batch, max_len, label_name):
34
  if max_len is None:
35
  max_len = max([len(i) for i in cell_batch["input_ids"]])
36
 
 
 
 
 
37
  def pad_label_example(example):
38
  example[label_name] = np.pad(
39
  example[label_name],
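For reference (not part of the diff): the padding pattern the truncated pad_label_example hunk above relies on, with labels padded by -100 (ignored by the loss) and input_ids padded by the <pad> token id. The ids below are toy values, not real tokens.

import numpy as np

pad_token_id = 0                       # stand-in for gene_token_dict.get("<pad>")
max_len = 8
input_ids = np.array([5, 17, 42])
labels = np.array([1, 0, 1])

padded_ids = np.pad(input_ids, (0, max_len - len(input_ids)),
                    mode="constant", constant_values=pad_token_id)
padded_labels = np.pad(labels, (0, max_len - len(labels)),
                       mode="constant", constant_values=-100)
# padded_ids    -> [ 5 17 42  0  0  0  0  0]
# padded_labels -> [  1   0   1 -100 -100 -100 -100 -100]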
geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:eac0fb0b3007267871b6305ac0003ceba19d4f28d85686cb9067ecf142787869
3
- size 584125
 
 
 
 
geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
3
- size 940965
 
 
 
 
geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
3
- size 788424
 
 
 
 
geneformer/gene_median_dictionary.pkl ADDED
Binary file (941 kB). View file
 
geneformer/gene_median_dictionary_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5
3
- size 1512661
 
 
 
 
geneformer/{gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl → gene_name_id_dict.pkl} RENAMED
File without changes
geneformer/gene_name_id_dict_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
3
- size 2038812
 
 
 
 
geneformer/in_silico_perturber.py CHANGED
@@ -38,17 +38,19 @@ import logging
38
  import os
39
  import pickle
40
  from collections import defaultdict
 
41
 
 
42
  import torch
43
- from datasets import Dataset, disable_progress_bars
44
- from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
47
- from . import TOKEN_DICTIONARY_FILE
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
 
 
 
50
 
51
- disable_progress_bars()
52
 
53
  logger = logging.getLogger(__name__)
54
 
@@ -60,9 +62,9 @@ class InSilicoPerturber:
60
  "genes_to_perturb": {"all", list},
61
  "combos": {0, 1},
62
  "anchor_gene": {None, str},
63
- "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
64
  "num_classes": {int},
65
- "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
66
  "cell_emb_style": {"mean_pool"},
67
  "filter_data": {None, dict},
68
  "cell_states_to_model": {None, dict},
@@ -70,7 +72,6 @@ class InSilicoPerturber:
70
  "max_ncells": {None, int},
71
  "cell_inds_to_perturb": {"all", dict},
72
  "emb_layer": {-1, 0},
73
- "token_dictionary_file": {None, str},
74
  "forward_batch_size": {int},
75
  "nproc": {int},
76
  }
@@ -94,8 +95,7 @@ class InSilicoPerturber:
94
  emb_layer=-1,
95
  forward_batch_size=100,
96
  nproc=4,
97
- token_dictionary_file=None,
98
- clear_mem_ncells=1000,
99
  ):
100
  """
101
  Initialize in silico perturber.
@@ -130,16 +130,16 @@ class InSilicoPerturber:
130
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
131
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
132
  | anchor gene will be perturbed in combination with each other gene.
133
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
134
- | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
135
  num_classes : int
136
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
137
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
138
- emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
139
- | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
140
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
141
  cell_emb_style : "mean_pool"
142
- | Method for summarizing cell embeddings if not using CLS token.
143
  | Currently only option is mean pooling of gene embeddings for given cell.
144
  filter_data : None, dict
145
  | Default is to use all input data for in silico perturbation study.
@@ -184,13 +184,7 @@ class InSilicoPerturber:
184
  | Number of CPU processes to use.
185
  token_dictionary_file : Path
186
  | Path to pickle file containing token dictionary (Ensembl ID:token).
187
- clear_mem_ncells : int
188
- | Clear memory every n cells.
189
  """
190
- try:
191
- set_start_method("spawn")
192
- except RuntimeError:
193
- pass
194
 
195
  self.perturb_type = perturb_type
196
  self.perturb_rank_shift = perturb_rank_shift
@@ -222,32 +216,14 @@ class InSilicoPerturber:
222
  self.emb_layer = emb_layer
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
225
- self.token_dictionary_file = token_dictionary_file
226
- self.clear_mem_ncells = clear_mem_ncells
227
 
228
  self.validate_options()
229
 
230
  # load token dictionary (Ensembl IDs:token)
231
- if self.token_dictionary_file is None:
232
- token_dictionary_file = TOKEN_DICTIONARY_FILE
233
  with open(token_dictionary_file, "rb") as f:
234
  self.gene_token_dict = pickle.load(f)
235
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
236
 
237
  self.pad_token_id = self.gene_token_dict.get("<pad>")
238
- self.cls_token_id = self.gene_token_dict.get("<cls>")
239
- self.eos_token_id = self.gene_token_dict.get("<eos>")
240
-
241
- # Identify if special token is present in the token dictionary
242
- if (self.cls_token_id is not None) and (self.eos_token_id is not None):
243
- self.special_token = True
244
- else:
245
- if "cls" in self.emb_mode:
246
- logger.error(
247
- f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
248
- )
249
- raise
250
- self.special_token = False
251
 
252
  if self.anchor_gene is None:
253
  self.anchor_token = None
@@ -305,7 +281,7 @@ class InSilicoPerturber:
305
  continue
306
  valid_type = False
307
  for option in valid_options:
308
- if (option in [bool, int, list, dict, str]) and isinstance(
309
  attr_value, option
310
  ):
311
  valid_type = True
@@ -451,45 +427,16 @@ class InSilicoPerturber:
451
  filtered_input_data = pu.load_and_filter(
452
  self.filter_data, self.nproc, input_data_file
453
  )
454
-
455
- # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
- if self.special_token:
457
- if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
458
- "cls" not in self.emb_mode
459
- ):
460
- logger.error(
461
- "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
462
- )
463
- raise
464
- if "cls" in self.emb_mode:
465
- if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
466
- filtered_input_data["input_ids"][0][-1] != self.eos_token_id
467
- ):
468
- logger.error(
469
- "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
470
- )
471
- raise
472
-
473
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
474
 
475
  if self.perturb_group is True:
476
- if (self.special_token) and ("cls" in self.emb_mode):
477
- self.isp_perturb_set_special(
478
- model, filtered_input_data, layer_to_quant, output_path_prefix
479
- )
480
- else:
481
- self.isp_perturb_set(
482
- model, filtered_input_data, layer_to_quant, output_path_prefix
483
- )
484
  else:
485
- if (self.special_token) and ("cls" in self.emb_mode):
486
- self.isp_perturb_all_special(
487
- model, filtered_input_data, layer_to_quant, output_path_prefix
488
- )
489
- else:
490
- self.isp_perturb_all(
491
- model, filtered_input_data, layer_to_quant, output_path_prefix
492
- )
493
 
494
  def apply_additional_filters(self, filtered_input_data):
495
  # additional filtering of input data dependent on isp mode
@@ -552,9 +499,7 @@ class InSilicoPerturber:
552
  if self.perturb_type == "delete":
553
  example = pu.delete_indices(example)
554
  elif self.perturb_type == "overexpress":
555
- example = pu.overexpress_tokens(
556
- example, self.max_len, self.special_token
557
- )
558
  example["n_overflow"] = pu.calc_n_overflow(
559
  self.max_len,
560
  example["length"],
@@ -575,7 +520,6 @@ class InSilicoPerturber:
575
  perturbed_data = filtered_input_data.map(
576
  make_group_perturbation_batch, num_proc=self.nproc
577
  )
578
-
579
  if self.perturb_type == "overexpress":
580
  filtered_input_data = filtered_input_data.add_column(
581
  "n_overflow", perturbed_data["n_overflow"]
@@ -608,7 +552,6 @@ class InSilicoPerturber:
608
  layer_to_quant,
609
  self.pad_token_id,
610
  self.forward_batch_size,
611
- token_gene_dict=self.token_gene_dict,
612
  summary_stat=None,
613
  silent=True,
614
  )
@@ -628,7 +571,6 @@ class InSilicoPerturber:
628
  layer_to_quant,
629
  self.pad_token_id,
630
  self.forward_batch_size,
631
- token_gene_dict=self.token_gene_dict,
632
  summary_stat=None,
633
  silent=True,
634
  )
@@ -728,6 +670,8 @@ class InSilicoPerturber:
728
  cos_sims_dict = self.update_perturbation_dictionary(
729
  cos_sims_dict,
730
  cos_sims_data,
 
 
731
  gene_list,
732
  )
733
  else:
@@ -736,6 +680,8 @@ class InSilicoPerturber:
736
  cos_sims_dict[state] = self.update_perturbation_dictionary(
737
  cos_sims_dict[state],
738
  cos_sims_data[state],
 
 
739
  gene_list,
740
  )
741
  del minibatch
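For reference (not part of the diff): the quantity accumulated in the loop above is a per-cell cosine similarity between perturbed and original embeddings. A compact illustration with toy tensors follows; it is not the library's quant_cos_sims implementation.

import torch
import torch.nn.functional as F

original_emb = torch.randn(4, 256)                         # 4 cells x hidden size (toy values)
perturbed_emb = original_emb + 0.05 * torch.randn(4, 256)  # small simulated perturbation shift

# one similarity per cell; values near 1 mean the perturbation barely moved the embedding
cos_sims_data = F.cosine_similarity(perturbed_emb, original_emb, dim=1)
print(cos_sims_data.shape)  # torch.Size([4])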
@@ -757,264 +703,6 @@ class InSilicoPerturber:
757
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
758
  )
759
 
760
- def isp_perturb_set_special(
761
- self,
762
- model,
763
- filtered_input_data: Dataset,
764
- layer_to_quant: int,
765
- output_path_prefix: str,
766
- ):
767
- def make_group_perturbation_batch(example):
768
- example_input_ids = example["input_ids"]
769
- example["tokens_to_perturb"] = self.tokens_to_perturb
770
- indices_to_perturb = [
771
- example_input_ids.index(token) if token in example_input_ids else None
772
- for token in self.tokens_to_perturb
773
- ]
774
- indices_to_perturb = [
775
- item for item in indices_to_perturb if item is not None
776
- ]
777
- if len(indices_to_perturb) > 0:
778
- example["perturb_index"] = indices_to_perturb
779
- else:
780
- # -100 indicates tokens to overexpress are not present in rank value encoding
781
- example["perturb_index"] = [-100]
782
- if self.perturb_type == "delete":
783
- example = pu.delete_indices(example)
784
- elif self.perturb_type == "overexpress":
785
- example = pu.overexpress_tokens(
786
- example, self.max_len, self.special_token
787
- )
788
- example["n_overflow"] = pu.calc_n_overflow(
789
- self.max_len,
790
- example["length"],
791
- self.tokens_to_perturb,
792
- indices_to_perturb,
793
- )
794
- return example
795
-
796
- total_batch_length = len(filtered_input_data)
797
- if self.cell_states_to_model is None:
798
- cos_sims_dict = defaultdict(list)
799
- else:
800
- cos_sims_dict = {
801
- state: defaultdict(list)
802
- for state in pu.get_possible_states(self.cell_states_to_model)
803
- }
804
-
805
- perturbed_data = filtered_input_data.map(
806
- make_group_perturbation_batch, num_proc=self.nproc
807
- )
808
-
809
- if self.perturb_type == "overexpress":
810
- filtered_input_data = filtered_input_data.add_column(
811
- "n_overflow", perturbed_data["n_overflow"]
812
- )
813
- filtered_input_data = filtered_input_data.map(
814
- pu.truncate_by_n_overflow_special, num_proc=self.nproc
815
- )
816
-
817
- if self.emb_mode == "cls_and_gene":
818
- stored_gene_embs_dict = defaultdict(list)
819
-
820
- # iterate through batches
821
- for i in trange(0, total_batch_length, self.forward_batch_size):
822
- max_range = min(i + self.forward_batch_size, total_batch_length)
823
- inds_select = [i for i in range(i, max_range)]
824
-
825
- minibatch = filtered_input_data.select(inds_select)
826
- perturbation_batch = perturbed_data.select(inds_select)
827
-
828
- ##### CLS Embedding Mode #####
829
- if self.emb_mode == "cls":
830
- indices_to_perturb = perturbation_batch["perturb_index"]
831
-
832
- original_cls_emb = get_embs(
833
- model,
834
- minibatch,
835
- "cls",
836
- layer_to_quant,
837
- self.pad_token_id,
838
- self.forward_batch_size,
839
- token_gene_dict=self.token_gene_dict,
840
- summary_stat=None,
841
- silent=True,
842
- )
843
-
844
- perturbation_cls_emb = get_embs(
845
- model,
846
- perturbation_batch,
847
- "cls",
848
- layer_to_quant,
849
- self.pad_token_id,
850
- self.forward_batch_size,
851
- token_gene_dict=self.token_gene_dict,
852
- summary_stat=None,
853
- silent=True,
854
- )
855
-
856
- # Calculate the cosine similarities
857
- cls_cos_sims = pu.quant_cos_sims(
858
- perturbation_cls_emb,
859
- original_cls_emb,
860
- self.cell_states_to_model,
861
- self.state_embs_dict,
862
- emb_mode="cell",
863
- )
864
-
865
- # Update perturbation dictionary
866
- if self.cell_states_to_model is None:
867
- cos_sims_dict = self.update_perturbation_dictionary(
868
- cos_sims_dict,
869
- cls_cos_sims,
870
- gene_list=None,
871
- )
872
- else:
873
- for state in cos_sims_dict.keys():
874
- cos_sims_dict[state] = self.update_perturbation_dictionary(
875
- cos_sims_dict[state],
876
- cls_cos_sims[state],
877
- gene_list=None,
878
- )
879
-
880
- ##### CLS and Gene Embedding Mode #####
881
- elif self.emb_mode == "cls_and_gene":
882
- full_original_emb = get_embs(
883
- model,
884
- minibatch,
885
- "gene",
886
- layer_to_quant,
887
- self.pad_token_id,
888
- self.forward_batch_size,
889
- self.token_gene_dict,
890
- summary_stat=None,
891
- silent=True,
892
- )
893
- indices_to_perturb = perturbation_batch["perturb_index"]
894
- # remove indices that were perturbed
895
- original_emb = pu.remove_perturbed_indices_set(
896
- full_original_emb,
897
- self.perturb_type,
898
- indices_to_perturb,
899
- self.tokens_to_perturb,
900
- minibatch["length"],
901
- )
902
- full_perturbation_emb = get_embs(
903
- model,
904
- perturbation_batch,
905
- "gene",
906
- layer_to_quant,
907
- self.pad_token_id,
908
- self.forward_batch_size,
909
- self.token_gene_dict,
910
- summary_stat=None,
911
- silent=True,
912
- )
913
-
914
- # remove special tokens and padding
915
- original_emb = original_emb[:, 1:-1, :]
916
- if self.perturb_type == "overexpress":
917
- perturbation_emb = full_perturbation_emb[
918
- :, 1 + len(self.tokens_to_perturb) : -1, :
919
- ]
920
- elif self.perturb_type == "delete":
921
- perturbation_emb = full_perturbation_emb[
922
- :, 1 : max(perturbation_batch["length"]) - 1, :
923
- ]
924
-
925
- n_perturbation_genes = perturbation_emb.size()[1]
926
-
927
- gene_cos_sims = pu.quant_cos_sims(
928
- perturbation_emb,
929
- original_emb,
930
- self.cell_states_to_model,
931
- self.state_embs_dict,
932
- emb_mode="gene",
933
- )
934
-
935
- # get cls emb
936
- original_cls_emb = full_original_emb[:, 0, :]
937
- perturbation_cls_emb = full_perturbation_emb[:, 0, :]
938
-
939
- cls_cos_sims = pu.quant_cos_sims(
940
- perturbation_cls_emb,
941
- original_cls_emb,
942
- self.cell_states_to_model,
943
- self.state_embs_dict,
944
- emb_mode="cell",
945
- )
946
-
947
- # get cosine similarities in gene embeddings
948
- # since getting gene embeddings, need gene names
949
-
950
- gene_list = minibatch["input_ids"]
951
- # need to truncate gene_list
952
- genes_to_exclude = self.tokens_to_perturb + [
953
- self.cls_token_id,
954
- self.eos_token_id,
955
- ]
956
- gene_list = [
957
- [g for g in genes if g not in genes_to_exclude][
958
- :n_perturbation_genes
959
- ]
960
- for genes in gene_list
961
- ]
962
-
963
- for cell_i, genes in enumerate(gene_list):
964
- for gene_j, affected_gene in enumerate(genes):
965
- if len(self.genes_to_perturb) > 1:
966
- tokens_to_perturb = tuple(self.tokens_to_perturb)
967
- else:
968
- tokens_to_perturb = self.tokens_to_perturb[0]
969
-
970
- # fill in the gene cosine similarities
971
- try:
972
- stored_gene_embs_dict[
973
- (tokens_to_perturb, affected_gene)
974
- ].append(gene_cos_sims[cell_i, gene_j].item())
975
- except KeyError:
976
- stored_gene_embs_dict[
977
- (tokens_to_perturb, affected_gene)
978
- ] = gene_cos_sims[cell_i, gene_j].item()
979
-
980
- if self.cell_states_to_model is None:
981
- cos_sims_dict = self.update_perturbation_dictionary(
982
- cos_sims_dict,
983
- cls_cos_sims,
984
- gene_list=None,
985
- )
986
- else:
987
- for state in cos_sims_dict.keys():
988
- cos_sims_dict[state] = self.update_perturbation_dictionary(
989
- cos_sims_dict[state],
990
- cls_cos_sims[state],
991
- gene_list=None,
992
- )
993
- del full_original_emb
994
- del original_emb
995
- del full_perturbation_emb
996
- del perturbation_emb
997
- del gene_cos_sims
998
-
999
- del original_cls_emb
1000
- del perturbation_cls_emb
1001
- del cls_cos_sims
1002
- del minibatch
1003
- del perturbation_batch
1004
-
1005
- torch.cuda.empty_cache()
1006
-
1007
- pu.write_perturbation_dictionary(
1008
- cos_sims_dict,
1009
- f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
1010
- )
1011
-
1012
- if self.emb_mode == "cls_and_gene":
1013
- pu.write_perturbation_dictionary(
1014
- stored_gene_embs_dict,
1015
- f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1016
- )
1017
-
1018
  def isp_perturb_all(
1019
  self,
1020
  model,
@@ -1033,10 +721,8 @@ class InSilicoPerturber:
1033
 
1034
  if self.emb_mode == "cell_and_gene":
1035
  stored_gene_embs_dict = defaultdict(list)
1036
-
1037
- num_inds_perturbed = 1 + self.combos
1038
- for h in trange(len(filtered_input_data)):
1039
- example_cell = filtered_input_data.select([h])
1040
  full_original_emb = get_embs(
1041
  model,
1042
  example_cell,
@@ -1044,30 +730,16 @@ class InSilicoPerturber:
1044
  layer_to_quant,
1045
  self.pad_token_id,
1046
  self.forward_batch_size,
1047
- self.token_gene_dict,
1048
  summary_stat=None,
1049
  silent=True,
1050
  )
1051
 
1052
- if self.cell_states_to_model is not None:
1053
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
1054
- full_original_emb, "mean_pool"
1055
- )
1056
-
1057
  # gene_list is used to assign cos sims back to genes
1058
- gene_list = example_cell["input_ids"][0][:]
1059
  # need to remove the anchor gene
 
1060
  if self.anchor_token is not None:
1061
  for token in self.anchor_token:
1062
  gene_list.remove(token)
1063
- # index 0 is not overexpressed so remove
1064
- if self.perturb_type == "overexpress":
1065
- gene_list = gene_list[num_inds_perturbed:]
1066
- # remove perturbed index for gene list dict
1067
- perturbed_gene_dict = {
1068
- gene: gene_list[:i] + gene_list[i + 1 :]
1069
- for i, gene in enumerate(gene_list)
1070
- }
1071
 
1072
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1073
  example_cell,
@@ -1078,459 +750,147 @@ class InSilicoPerturber:
1078
  self.nproc,
1079
  )
1080
 
1081
- ispall_total_batch_length = len(perturbation_batch)
1082
- for i in trange(
1083
- 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1084
- ):
1085
- ispall_max_range = min(
1086
- i + self.forward_batch_size, ispall_total_batch_length
1087
- )
1088
- perturbation_minibatch = perturbation_batch.select(
1089
- [i for i in range(i, ispall_max_range)]
1090
- )
1091
- indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1092
- gene_list_mini = gene_list[
1093
- i:ispall_max_range
1094
- ] # only perturbed genes from this minibatch
1095
-
1096
- full_perturbation_emb = get_embs(
1097
- model,
1098
- perturbation_minibatch,
1099
- "gene",
1100
- layer_to_quant,
1101
- self.pad_token_id,
1102
- self.forward_batch_size,
1103
- self.token_gene_dict,
1104
- summary_stat=None,
1105
- silent=True,
1106
- )
1107
-
1108
- del perturbation_minibatch
1109
-
1110
- # need to remove overexpressed gene to quantify cosine shifts
1111
- if self.perturb_type == "overexpress":
1112
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1113
-
1114
- elif self.perturb_type == "delete":
1115
- perturbation_emb = full_perturbation_emb
1116
-
1117
- if (
1118
- self.cell_states_to_model is None
1119
- or self.emb_mode == "cell_and_gene"
1120
- ):
1121
- original_emb_minibatch = pu.make_comparison_batch(
1122
- full_original_emb, indices_to_perturb_mini, perturb_group=False
1123
- )
1124
- gene_cos_sims = pu.quant_cos_sims(
1125
- perturbation_emb,
1126
- original_emb_minibatch,
1127
- self.cell_states_to_model,
1128
- self.state_embs_dict,
1129
- emb_mode="gene",
1130
- )
1131
- del original_emb_minibatch
1132
-
1133
- if self.cell_states_to_model is not None:
1134
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1135
- full_perturbation_emb, "mean_pool"
1136
- )
1137
-
1138
- cell_cos_sims = pu.quant_cos_sims(
1139
- perturbation_cell_emb,
1140
- original_cell_emb,
1141
- self.cell_states_to_model,
1142
- self.state_embs_dict,
1143
- emb_mode="cell",
1144
- )
1145
- del perturbation_cell_emb
1146
-
1147
- if self.emb_mode == "cell_and_gene":
1148
- for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1149
- for gene_j, affected_gene in enumerate(
1150
- perturbed_gene_dict[perturbed_gene]
1151
- ):
1152
- try:
1153
- stored_gene_embs_dict[
1154
- (perturbed_gene, affected_gene)
1155
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
1156
- except KeyError:
1157
- stored_gene_embs_dict[
1158
- (perturbed_gene, affected_gene)
1159
- ] = gene_cos_sims[perturbation_i, gene_j].item()
1160
 
1161
- del full_perturbation_emb
 
1163
- if self.cell_states_to_model is None:
1164
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1165
- cos_sims_dict = self.update_perturbation_dictionary(
1166
- cos_sims_dict,
1167
- cos_sims_data,
1168
- gene_list_mini,
1169
- )
1170
- else:
1171
- cos_sims_data = cell_cos_sims
1172
- for state in cos_sims_dict.keys():
1173
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1174
- cos_sims_dict[state],
1175
- cos_sims_data[state],
1176
- gene_list_mini,
1177
- )
1178
-
1179
- # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1180
- if i % self.clear_mem_ncells / 10 == 0:
1181
- pu.write_perturbation_dictionary(
1182
- cos_sims_dict,
1183
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1184
- )
1185
- if self.emb_mode == "cell_and_gene":
1186
- pu.write_perturbation_dictionary(
1187
- stored_gene_embs_dict,
1188
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1189
- )
1190
-
1191
- # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1192
- if i % self.clear_mem_ncells == 0:
1193
- pickle_batch += 1
1194
- if self.cell_states_to_model is None:
1195
- cos_sims_dict = defaultdict(list)
1196
- else:
1197
- cos_sims_dict = {
1198
- state: defaultdict(list)
1199
- for state in pu.get_possible_states(
1200
- self.cell_states_to_model
1201
- )
1202
- }
1203
-
1204
- if self.emb_mode == "cell_and_gene":
1205
- stored_gene_embs_dict = defaultdict(list)
1206
-
1207
- torch.cuda.empty_cache()
1208
 
1209
- pu.write_perturbation_dictionary(
1210
- cos_sims_dict,
1211
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1212
  )
1213
 
1214
- if self.emb_mode == "cell_and_gene":
1215
- pu.write_perturbation_dictionary(
1216
- stored_gene_embs_dict,
1217
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
 
 
 
1218
  )
1219
-
1220
- pickle_batch = -1
1221
- if self.cell_states_to_model is None:
1222
- cos_sims_dict = defaultdict(list)
1223
- else:
1224
- cos_sims_dict = {
1225
- state: defaultdict(list)
1226
- for state in pu.get_possible_states(self.cell_states_to_model)
1227
- }
1228
-
1229
- if self.emb_mode == "cell_and_gene":
1230
- stored_gene_embs_dict = defaultdict(list)
1231
-
1232
- # clear memory between cells
1233
- del perturbation_batch
1234
- del full_original_emb
1235
  if self.cell_states_to_model is not None:
1236
- del original_cell_emb
1237
- torch.cuda.empty_cache()
1238
-
1239
- def isp_perturb_all_special(
1240
- self,
1241
- model,
1242
- filtered_input_data: Dataset,
1243
- layer_to_quant: int,
1244
- output_path_prefix: str,
1245
- ):
1246
- pickle_batch = -1
1247
- if self.cell_states_to_model is None:
1248
- cos_sims_dict = defaultdict(list)
1249
- else:
1250
- cos_sims_dict = {
1251
- state: defaultdict(list)
1252
- for state in pu.get_possible_states(self.cell_states_to_model)
1253
- }
1254
-
1255
- if self.emb_mode == "cls_and_gene":
1256
- stored_gene_embs_dict = defaultdict(list)
1257
-
1258
- num_inds_perturbed = 1 + self.combos
1259
- for h in trange(len(filtered_input_data)):
1260
- example_cell = filtered_input_data.select([h])
1261
-
1262
- # get original example cell cls and/or gene embs for comparison
1263
- if self.emb_mode == "cls":
1264
- original_cls_emb = get_embs(
1265
- model,
1266
- example_cell,
1267
- "cls",
1268
- layer_to_quant,
1269
- self.pad_token_id,
1270
- self.forward_batch_size,
1271
- self.token_gene_dict,
1272
- summary_stat=None,
1273
- silent=True,
1274
  )
1275
- elif self.emb_mode == "cls_and_gene":
1276
- full_original_emb = get_embs(
1277
- model,
1278
- example_cell,
1279
- "gene",
1280
- layer_to_quant,
1281
- self.pad_token_id,
1282
- self.forward_batch_size,
1283
- self.token_gene_dict,
1284
- summary_stat=None,
1285
- silent=True,
1286
  )
1287
- original_cls_emb = full_original_emb[:, 0, :].clone().detach()
1288
-
1289
- # gene_list is used to assign cos sims back to genes
1290
- gene_list = example_cell["input_ids"][0][:]
1291
 
1292
- # need to remove special tokens
1293
- for token in [self.cls_token_id, self.eos_token_id]:
1294
- gene_list.remove(token)
1295
- # need to remove the anchor gene
1296
- if self.anchor_token is not None:
1297
- for token in self.anchor_token:
1298
- gene_list.remove(token)
1299
- # index 0 is not overexpressed so remove
1300
- if self.perturb_type == "overexpress":
1301
- gene_list = gene_list[num_inds_perturbed:]
1302
- # remove perturbed index for gene list dict
1303
- perturbed_gene_dict = {
1304
- gene: gene_list[:i] + gene_list[i + 1 :]
1305
- for i, gene in enumerate(gene_list)
1306
- }
1307
-
1308
- perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1309
- example_cell,
1310
- self.perturb_type,
1311
- self.tokens_to_perturb,
1312
- self.anchor_token,
1313
- self.combos,
1314
- self.nproc,
1315
- )
1316
-
1317
- ispall_total_batch_length = len(perturbation_batch)
1318
- for i in trange(
1319
- 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1320
- ):
1321
- ispall_max_range = min(
1322
- i + self.forward_batch_size, ispall_total_batch_length
1323
- )
1324
- perturbation_minibatch = perturbation_batch.select(
1325
- [i for i in range(i, ispall_max_range)]
1326
  )
1327
- indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1328
- gene_list_mini = gene_list[
1329
- i:ispall_max_range
1330
- ] # only perturbed genes from this minibatch
1331
-
1332
- ##### CLS Embedding Mode #####
1333
- if self.emb_mode == "cls":
1334
- # Extract cls embeddings from perturbed cells
1335
- perturbation_cls_emb = get_embs(
1336
- model,
1337
- perturbation_minibatch,
1338
- "cls",
1339
- layer_to_quant,
1340
- self.pad_token_id,
1341
- self.forward_batch_size,
1342
- self.token_gene_dict,
1343
- summary_stat=None,
1344
- silent=True,
1345
- )
1346
 
1347
- # Calculate cosine similarities
1348
- cls_cos_sims = pu.quant_cos_sims(
1349
- perturbation_cls_emb,
1350
- original_cls_emb,
1351
- self.cell_states_to_model,
1352
- self.state_embs_dict,
1353
- emb_mode="cell",
1354
- )
1355
 
1356
- if self.cell_states_to_model is None:
1357
- cos_sims_dict = self.update_perturbation_dictionary(
1358
- cos_sims_dict,
1359
- cls_cos_sims,
1360
- gene_list_mini,
1361
- )
1362
- else:
1363
- for state in cos_sims_dict.keys():
1364
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1365
- cos_sims_dict[state],
1366
- cls_cos_sims[state],
1367
- gene_list_mini,
1368
- )
1369
-
1370
- del perturbation_minibatch
1371
- del perturbation_cls_emb
1372
- del cls_cos_sims
1373
-
1374
- ##### CLS and Gene Embedding Mode #####
1375
- elif self.emb_mode == "cls_and_gene":
1376
- full_perturbation_emb = get_embs(
1377
- model,
1378
- perturbation_minibatch,
1379
- "gene",
1380
- layer_to_quant,
1381
- self.pad_token_id,
1382
- self.forward_batch_size,
1383
- self.token_gene_dict,
1384
- summary_stat=None,
1385
- silent=True,
1386
- )
1387
 
1388
- # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1389
- if self.perturb_type == "overexpress":
1390
- perturbation_emb = (
1391
- full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
1392
- .clone()
1393
- .detach()
1394
- )
1395
- elif self.perturb_type == "delete":
1396
- perturbation_emb = (
1397
- full_perturbation_emb[:, 1:-1, :].clone().detach()
1398
- )
1399
-
1400
- original_emb_minibatch = pu.make_comparison_batch(
1401
- full_original_emb, indices_to_perturb_mini, perturb_group=False
 
 
 
 
1402
  )
1403
 
1404
- original_emb_minibatch = (
1405
- original_emb_minibatch[:, 1:-1, :].clone().detach()
1406
- )
1407
- gene_cos_sims = pu.quant_cos_sims(
1408
- perturbation_emb,
1409
- original_emb_minibatch,
1410
- self.cell_states_to_model,
1411
- self.state_embs_dict,
1412
- emb_mode="gene",
 
1413
  )
1414
 
1415
- for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1416
- for gene_j, affected_gene in enumerate(
1417
- perturbed_gene_dict[perturbed_gene]
1418
- ):
1419
- try:
1420
- stored_gene_embs_dict[
1421
- (perturbed_gene, affected_gene)
1422
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
1423
- except KeyError:
1424
- stored_gene_embs_dict[
1425
- (perturbed_gene, affected_gene)
1426
- ] = gene_cos_sims[perturbation_i, gene_j].item()
1427
 
1428
- # get cls emb
1429
- perturbation_cls_emb = (
1430
- full_perturbation_emb[:, 0, :].clone().detach()
1431
- )
1432
 
1433
- cls_cos_sims = pu.quant_cos_sims(
1434
- perturbation_cls_emb,
1435
- original_cls_emb,
1436
- self.cell_states_to_model,
1437
- self.state_embs_dict,
1438
- emb_mode="cell",
1439
- )
1440
 
1441
- if self.cell_states_to_model is None:
1442
- cos_sims_dict = self.update_perturbation_dictionary(
1443
- cos_sims_dict,
1444
- cls_cos_sims,
1445
- gene_list_mini,
1446
- )
1447
- else:
1448
- for state in cos_sims_dict.keys():
1449
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1450
- cos_sims_dict[state],
1451
- cls_cos_sims[state],
1452
- gene_list_mini,
1453
- )
1454
-
1455
- del perturbation_minibatch
1456
- del original_emb_minibatch
1457
- del full_perturbation_emb
1458
- del perturbation_emb
1459
- del perturbation_cls_emb
1460
- del cls_cos_sims
1461
- del gene_cos_sims
1462
-
1463
- # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1464
- if i % max(1, self.clear_mem_ncells / 10) == 0:
1465
- pu.write_perturbation_dictionary(
1466
- cos_sims_dict,
1467
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1468
- )
1469
- if self.emb_mode == "cls_and_gene":
1470
- pu.write_perturbation_dictionary(
1471
- stored_gene_embs_dict,
1472
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1473
- )
1474
-
1475
- # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1476
- if i % self.clear_mem_ncells == 0:
1477
- pickle_batch += 1
1478
- if self.cell_states_to_model is None:
1479
- cos_sims_dict = defaultdict(list)
1480
- else:
1481
- cos_sims_dict = {
1482
- state: defaultdict(list)
1483
- for state in pu.get_possible_states(
1484
- self.cell_states_to_model
1485
- )
1486
- }
1487
-
1488
- if self.emb_mode == "cls_and_gene":
1489
- stored_gene_embs_dict = defaultdict(list)
1490
-
1491
- torch.cuda.empty_cache()
1492
 
 
1493
  pu.write_perturbation_dictionary(
1494
- cos_sims_dict,
1495
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1496
  )
1497
 
1498
- if self.emb_mode == "cls_and_gene":
1499
- pu.write_perturbation_dictionary(
1500
- stored_gene_embs_dict,
1501
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1502
- )
1503
-
1504
- pickle_batch = -1
1505
- if self.cell_states_to_model is None:
1506
- cos_sims_dict = defaultdict(list)
1507
- else:
1508
- cos_sims_dict = {
1509
- state: defaultdict(list)
1510
- for state in pu.get_possible_states(self.cell_states_to_model)
1511
- }
1512
-
1513
- if self.emb_mode == "cls_and_gene":
1514
- stored_gene_embs_dict = defaultdict(list)
1515
-
1516
- # clear memory between cells
1517
- del perturbation_batch
1518
- del original_cls_emb
1519
- if self.emb_mode == "cls_and_gene":
1520
- del full_original_emb
1521
- torch.cuda.empty_cache()
1522
-
1523
  def update_perturbation_dictionary(
1524
  self,
1525
  cos_sims_dict: defaultdict,
1526
  cos_sims_data: torch.Tensor,
 
 
1527
  gene_list=None,
1528
  ):
1529
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1530
  logger.error(
1531
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1532
- {cos_sims_data.shape[0]=}.\n \
1533
- {len(gene_list)=}."
1534
  )
1535
  raise
1536
 
 
  import os
  import pickle
  from collections import defaultdict
+ from typing import List

+ import seaborn as sns
  import torch
+ from datasets import Dataset
  from tqdm.auto import trange

  from . import perturber_utils as pu
  from .emb_extractor import get_embs
+ from .tokenizer import TOKEN_DICTIONARY_FILE
+
+ sns.set()

  logger = logging.getLogger(__name__)

  "genes_to_perturb": {"all", list},
  "combos": {0, 1},
  "anchor_gene": {None, str},
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
  "num_classes": {int},
+ "emb_mode": {"cell", "cell_and_gene"},
  "cell_emb_style": {"mean_pool"},
  "filter_data": {None, dict},
  "cell_states_to_model": {None, dict},
  "max_ncells": {None, int},
  "cell_inds_to_perturb": {"all", dict},
  "emb_layer": {-1, 0},
  "forward_batch_size": {int},
  "nproc": {int},
  }

  emb_layer=-1,
  forward_batch_size=100,
  nproc=4,
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
  ):
  """
  Initialize in silico perturber.

  | ENSEMBL ID of gene to use as anchor in combination perturbations.
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
  | anchor gene will be perturbed in combination with each other gene.
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
+ | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
  num_classes : int
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
+ emb_mode : {"cell", "cell_and_gene"}
+ | Whether to output impact of perturbation on cell and/or gene embeddings.
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
  cell_emb_style : "mean_pool"
+ | Method for summarizing cell embeddings.
  | Currently only option is mean pooling of gene embeddings for given cell.
  filter_data : None, dict
  | Default is to use all input data for in silico perturbation study.

  | Number of CPU processes to use.
  token_dictionary_file : Path
  | Path to pickle file containing token dictionary (Ensembl ID:token).
  """
 
 
 
 
188
 
189
  self.perturb_type = perturb_type
190
  self.perturb_rank_shift = perturb_rank_shift
 
216
  self.emb_layer = emb_layer
217
  self.forward_batch_size = forward_batch_size
218
  self.nproc = nproc
 
 
219
 
220
  self.validate_options()
221
 
222
  # load token dictionary (Ensembl IDs:token)
 
 
223
  with open(token_dictionary_file, "rb") as f:
224
  self.gene_token_dict = pickle.load(f)
 
225
 
226
  self.pad_token_id = self.gene_token_dict.get("<pad>")
 
 
227
 
228
  if self.anchor_gene is None:
229
  self.anchor_token = None
 
281
  continue
282
  valid_type = False
283
  for option in valid_options:
284
+ if (option in [bool, int, list, dict]) and isinstance(
285
  attr_value, option
286
  ):
287
  valid_type = True
 
427
  filtered_input_data = pu.load_and_filter(
428
  self.filter_data, self.nproc, input_data_file
429
  )
 
 
430
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
431
 
432
  if self.perturb_group is True:
433
+ self.isp_perturb_set(
434
+ model, filtered_input_data, layer_to_quant, output_path_prefix
435
+ )
 
 
 
 
 
436
  else:
437
+ self.isp_perturb_all(
438
+ model, filtered_input_data, layer_to_quant, output_path_prefix
439
+ )
 
 
 
 
 
440
 
441
  def apply_additional_filters(self, filtered_input_data):
442
  # additional filtering of input data dependent on isp mode
 
499
  if self.perturb_type == "delete":
500
  example = pu.delete_indices(example)
501
  elif self.perturb_type == "overexpress":
502
+ example = pu.overexpress_tokens(example, self.max_len)
 
 
503
  example["n_overflow"] = pu.calc_n_overflow(
504
  self.max_len,
505
  example["length"],
 
520
  perturbed_data = filtered_input_data.map(
521
  make_group_perturbation_batch, num_proc=self.nproc
522
  )
 
523
  if self.perturb_type == "overexpress":
524
  filtered_input_data = filtered_input_data.add_column(
525
  "n_overflow", perturbed_data["n_overflow"]
 
552
  layer_to_quant,
553
  self.pad_token_id,
554
  self.forward_batch_size,
 
555
  summary_stat=None,
556
  silent=True,
557
  )
 
571
  layer_to_quant,
572
  self.pad_token_id,
573
  self.forward_batch_size,
 
574
  summary_stat=None,
575
  silent=True,
576
  )
 
670
  cos_sims_dict = self.update_perturbation_dictionary(
671
  cos_sims_dict,
672
  cos_sims_data,
673
+ filtered_input_data,
674
+ indices_to_perturb,
675
  gene_list,
676
  )
677
  else:
 
680
  cos_sims_dict[state] = self.update_perturbation_dictionary(
681
  cos_sims_dict[state],
682
  cos_sims_data[state],
683
+ filtered_input_data,
684
+ indices_to_perturb,
685
  gene_list,
686
  )
687
  del minibatch
 
703
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
704
  )
705
 
 
 
 
706
  def isp_perturb_all(
707
  self,
708
  model,
 
721
 
722
  if self.emb_mode == "cell_and_gene":
723
  stored_gene_embs_dict = defaultdict(list)
724
+ for i in trange(len(filtered_input_data)):
725
+ example_cell = filtered_input_data.select([i])
 
 
726
  full_original_emb = get_embs(
727
  model,
728
  example_cell,
 
730
  layer_to_quant,
731
  self.pad_token_id,
732
  self.forward_batch_size,
 
733
  summary_stat=None,
734
  silent=True,
735
  )
736
 
 
 
 
 
 
737
  # gene_list is used to assign cos sims back to genes
 
738
  # need to remove the anchor gene
739
+ gene_list = example_cell["input_ids"][0][:]
740
  if self.anchor_token is not None:
741
  for token in self.anchor_token:
742
  gene_list.remove(token)
 
 
743
 
744
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
745
  example_cell,
 
750
  self.nproc,
751
  )
752
 
753
+ full_perturbation_emb = get_embs(
754
+ model,
755
+ perturbation_batch,
756
+ "gene",
757
+ layer_to_quant,
758
+ self.pad_token_id,
759
+ self.forward_batch_size,
760
+ summary_stat=None,
761
+ silent=True,
762
+ )
 
 
 
 
763
 
764
+ num_inds_perturbed = 1 + self.combos
765
+ # need to remove overexpressed gene to quantify cosine shifts
766
+ if self.perturb_type == "overexpress":
767
+ perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
768
+ gene_list = gene_list[
769
+ num_inds_perturbed:
770
+ ] # index 0 is not overexpressed
771
 
772
+ elif self.perturb_type == "delete":
773
+ perturbation_emb = full_perturbation_emb
 
 
 
 
 
774
 
775
+ original_batch = pu.make_comparison_batch(
776
+ full_original_emb, indices_to_perturb, perturb_group=False
 
777
  )
778
 
779
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
780
+ gene_cos_sims = pu.quant_cos_sims(
781
+ perturbation_emb,
782
+ original_batch,
783
+ self.cell_states_to_model,
784
+ self.state_embs_dict,
785
+ emb_mode="gene",
786
  )
 
 
787
  if self.cell_states_to_model is not None:
788
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
789
+ full_original_emb, "mean_pool"
 
 
 
 
790
  )
791
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
792
+ full_perturbation_emb, "mean_pool"
 
 
793
  )
 
 
 
 
794
 
795
+ cell_cos_sims = pu.quant_cos_sims(
796
+ perturbation_cell_emb,
797
+ original_cell_emb,
798
+ self.cell_states_to_model,
799
+ self.state_embs_dict,
800
+ emb_mode="cell",
 
 
 
 
801
  )
 
 
 
 
802
 
803
+ if self.emb_mode == "cell_and_gene":
804
+ # remove perturbed index for gene list
805
+ perturbed_gene_dict = {
806
+ gene: gene_list[:i] + gene_list[i + 1 :]
807
+ for i, gene in enumerate(gene_list)
808
+ }
 
 
809
 
810
+ for perturbation_i, perturbed_gene in enumerate(gene_list):
811
+ for gene_j, affected_gene in enumerate(
812
+ perturbed_gene_dict[perturbed_gene]
813
+ ):
814
+ try:
815
+ stored_gene_embs_dict[
816
+ (perturbed_gene, affected_gene)
817
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
818
+ except KeyError:
819
+ stored_gene_embs_dict[
820
+ (perturbed_gene, affected_gene)
821
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
 
 
 
 
822
 
823
+ if self.cell_states_to_model is None:
824
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
825
+ cos_sims_dict = self.update_perturbation_dictionary(
826
+ cos_sims_dict,
827
+ cos_sims_data,
828
+ filtered_input_data,
829
+ indices_to_perturb,
830
+ gene_list,
831
+ )
832
+ else:
833
+ cos_sims_data = cell_cos_sims
834
+ for state in cos_sims_dict.keys():
835
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
836
+ cos_sims_dict[state],
837
+ cos_sims_data[state],
838
+ filtered_input_data,
839
+ indices_to_perturb,
840
+ gene_list,
841
  )
842
 
843
+ # save dict to disk every 100 cells
844
+ if i % 100 == 0:
845
+ pu.write_perturbation_dictionary(
846
+ cos_sims_dict,
847
+ f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
848
+ )
849
+ if self.emb_mode == "cell_and_gene":
850
+ pu.write_perturbation_dictionary(
851
+ stored_gene_embs_dict,
852
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
853
  )
854
 
855
+ # reset and clear memory every 1000 cells
856
+ if i % 1000 == 0:
857
+ pickle_batch += 1
858
+ if self.cell_states_to_model is None:
859
+ cos_sims_dict = defaultdict(list)
860
+ else:
861
+ cos_sims_dict = {
862
+ state: defaultdict(list)
863
+ for state in pu.get_possible_states(self.cell_states_to_model)
864
+ }
 
 
865
 
866
+ if self.emb_mode == "cell_and_gene":
867
+ stored_gene_embs_dict = defaultdict(list)
 
 
868
 
869
+ torch.cuda.empty_cache()
 
 
 
 
 
 
870
 
871
+ pu.write_perturbation_dictionary(
872
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
873
+ )
 
 
 
 
874
 
875
+ if self.emb_mode == "cell_and_gene":
876
  pu.write_perturbation_dictionary(
877
+ stored_gene_embs_dict,
878
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
879
  )
880
 
 
 
 
  def update_perturbation_dictionary(
      self,
      cos_sims_dict: defaultdict,
      cos_sims_data: torch.Tensor,
+     filtered_input_data: Dataset,
+     indices_to_perturb: List[List[int]],
      gene_list=None,
  ):
      if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
          logger.error(
              f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
+             cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
+             len(gene_list) = {len(gene_list)}."
          )
          raise
896
 
geneformer/in_silico_perturber_stats.py CHANGED
@@ -37,8 +37,10 @@ from scipy.stats import ranksums
37
  from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
40
- from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
41
  from .perturber_utils import flatten_list, validate_cell_states_to_model
 
 
 
42
 
43
  logger = logging.getLogger(__name__)
44
 
@@ -114,7 +116,6 @@ def read_dictionaries(
114
  state_dict[state_value][key] += new_dict[key]
115
  except KeyError:
116
  state_dict[state_value][key] = new_dict[key]
117
-
118
  if not file_found:
119
  logger.error(
120
  "No raw data for processing found within provided directory. "
@@ -191,69 +192,34 @@ def get_impact_component(test_value, gaussian_mixture_model):
191
 
192
 
193
  # aggregate data for single perturbation in multiple cells
194
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
195
- names = ["Cosine_sim", "Gene"]
196
- cos_sims_full_dfs = []
197
- if isinstance(genes_perturbed, list):
198
- if len(genes_perturbed) > 1:
199
- gene_ids_df = cos_sims_df.loc[
200
- np.isin(
201
- [set(idx) for idx in cos_sims_df["Ensembl_ID"]],
202
- set(genes_perturbed),
203
- ),
204
- :,
205
- ]
206
- else:
207
- gene_ids_df = cos_sims_df.loc[
208
- np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
209
- ]
210
- else:
211
- logger.error(
212
- "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
213
- )
214
- raise
215
-
216
- if gene_ids_df.empty:
217
- logger.error("genes_to_perturb not found in data.")
218
- raise
219
-
220
- tokens = gene_ids_df["Gene"]
221
- symbols = gene_ids_df["Gene_name"]
222
-
223
- for token, symbol in zip(tokens, symbols):
224
- cos_shift_data = []
225
- for dict_i in dict_list:
226
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
227
-
228
- df = pd.DataFrame(columns=names)
229
- df["Cosine_sim"] = cos_shift_data
230
- df["Gene"] = symbol
231
- cos_sims_full_dfs.append(df)
232
 
233
- return pd.concat(cos_sims_full_dfs)
 
 
 
 
 
234
 
235
 
236
  def find(variable, x):
237
  try:
238
  if x in variable: # Test if variable is iterable and contains x
239
  return True
240
- elif x == variable:
241
- return True
242
  except (ValueError, TypeError):
243
  return x == variable # Test if variable is x if non-iterable
244
 
245
 
246
  def isp_aggregate_gene_shifts(
247
- cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
248
  ):
249
  cos_shift_data = dict()
250
  for i in trange(cos_sims_df.shape[0]):
251
  token = cos_sims_df["Gene"][i]
252
  for dict_i in dict_list:
253
- if token_dtype == "nontuple":
254
- affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
255
- else:
256
- affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
257
  for key in affected_pairs:
258
  if key in cos_shift_data.keys():
259
  cos_shift_data[key] += dict_i.get(key, [])
@@ -266,11 +232,11 @@ def isp_aggregate_gene_shifts(
266
  cos_sims_full_df = pd.DataFrame()
267
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
268
  cos_sims_full_df["Gene_name"] = [
269
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
270
  for k, v in cos_data_mean.items()
271
  ]
272
  cos_sims_full_df["Ensembl_ID"] = [
273
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
274
  for k, v in cos_data_mean.items()
275
  ]
276
 
@@ -282,15 +248,15 @@ def isp_aggregate_gene_shifts(
282
  cos_sims_full_df["Affected_Ensembl_ID"] = [
283
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
284
  ]
285
- cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
286
- cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
287
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
288
 
289
  specific_val = "cell_emb"
290
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
291
- # reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
292
  cos_sims_full_df = cos_sims_full_df.sort_values(
293
- by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
294
  ).drop("temp", axis=1)
295
 
296
  return cos_sims_full_df
@@ -681,7 +647,7 @@ class InSilicoPerturberStats:
681
  cell_states_to_model=None,
682
  pickle_suffix="_raw.pickle",
683
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
684
- gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
685
  ):
686
  """
687
  Initialize in silico perturber stats generator.
@@ -700,7 +666,7 @@ class InSilicoPerturberStats:
700
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
701
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
702
  combos : {0,1,2}
703
- | Whether genex perturbed in isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
704
  anchor_gene : None, str
705
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
706
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
@@ -948,11 +914,11 @@ class InSilicoPerturberStats:
948
  | 1: within impact component; 0: not within impact component
949
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
950
 
951
- | In case of aggregating data / gene shifts:
952
  | "Perturbed": ID(s) of gene(s) being perturbed
953
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
954
- | "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
955
- | "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
956
  """
957
 
958
  if self.mode not in [
@@ -1052,30 +1018,14 @@ class InSilicoPerturberStats:
1052
  )
1053
 
1054
  elif self.mode == "aggregate_data":
1055
- cos_sims_df = isp_aggregate_grouped_perturb(
1056
- cos_sims_df_initial, dict_list, self.genes_perturbed
1057
- )
1058
 
1059
  elif self.mode == "aggregate_gene_shifts":
1060
- if (self.genes_perturbed == "all") and (self.combos == 0):
1061
- tuple_types = [
1062
- True if isinstance(genes, tuple) else False for genes in gene_list
1063
- ]
1064
- if all(tuple_types):
1065
- token_dtype = "tuple"
1066
- elif not any(tuple_types):
1067
- token_dtype = "nontuple"
1068
- else:
1069
- token_dtype = "mix"
1070
- else:
1071
- token_dtype = "mix"
1072
-
1073
  cos_sims_df = isp_aggregate_gene_shifts(
1074
  cos_sims_df_initial,
1075
  dict_list,
1076
  self.gene_token_id_dict,
1077
  self.gene_id_name_dict,
1078
- token_dtype,
1079
  )
1080
 
1081
  # save perturbation stats to output_path
 
37
  from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
+ from .tokenizer import TOKEN_DICTIONARY_FILE
42
+
43
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
44
 
45
  logger = logging.getLogger(__name__)
46
 
 
116
  state_dict[state_value][key] += new_dict[key]
117
  except KeyError:
118
  state_dict[state_value][key] = new_dict[key]
 
119
  if not file_found:
120
  logger.error(
121
  "No raw data for processing found within provided directory. "
 
192
 
193
 
194
  # aggregate data for single perturbation in multiple cells
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
+     names = ["Cosine_shift"]
+     cos_sims_full_df = pd.DataFrame(columns=names)
+
+     cos_shift_data = []
+     token = cos_sims_df["Gene"][0]
+     for dict_i in dict_list:
+         cos_shift_data += dict_i.get((token, "cell_emb"), [])
+     cos_sims_full_df["Cosine_shift"] = cos_shift_data
+     return cos_sims_full_df
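The raw pickles in dict_list map a (perturbed_token, "cell_emb") key to a list of cosine shifts, typically one value per cell in which that perturbation was modeled, e.g. {(perturbed_token, "cell_emb"): [0.98, 0.97, ...]} (values illustrative); the function above concatenates those lists for the single perturbed gene or group into the one-column "Cosine_shift" dataframe.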
205
 
206
 
207
  def find(variable, x):
208
  try:
209
  if x in variable: # Test if variable is iterable and contains x
210
  return True
 
 
211
  except (ValueError, TypeError):
212
  return x == variable # Test if variable is x if non-iterable
213
 
214
 
215
  def isp_aggregate_gene_shifts(
216
+ cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
217
  ):
218
  cos_shift_data = dict()
219
  for i in trange(cos_sims_df.shape[0]):
220
  token = cos_sims_df["Gene"][i]
221
  for dict_i in dict_list:
222
+ affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
 
 
 
223
  for key in affected_pairs:
224
  if key in cos_shift_data.keys():
225
  cos_shift_data[key] += dict_i.get(key, [])
 
232
  cos_sims_full_df = pd.DataFrame()
233
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
234
  cos_sims_full_df["Gene_name"] = [
235
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
236
  for k, v in cos_data_mean.items()
237
  ]
238
  cos_sims_full_df["Ensembl_ID"] = [
239
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
240
  for k, v in cos_data_mean.items()
241
  ]
242
 
 
248
  cos_sims_full_df["Affected_Ensembl_ID"] = [
249
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
250
  ]
251
+ cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
252
+ cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
253
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
254
 
255
  specific_val = "cell_emb"
256
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
257
+ # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
258
  cos_sims_full_df = cos_sims_full_df.sort_values(
259
+ by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
260
  ).drop("temp", axis=1)
261
 
262
  return cos_sims_full_df
 
647
  cell_states_to_model=None,
648
  pickle_suffix="_raw.pickle",
649
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
650
+ gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
651
  ):
652
  """
653
  Initialize in silico perturber stats generator.
 
666
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
667
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
668
  combos : {0,1,2}
669
+ | Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
670
  anchor_gene : None, str
671
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
672
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
 
  | 1: within impact component; 0: not within impact component
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component

+ | In case of aggregating gene shifts:
  | "Perturbed": ID(s) of gene(s) being perturbed
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
+ | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
+ | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
  """
923
 
924
  if self.mode not in [
 
1018
  )
1019
 
1020
  elif self.mode == "aggregate_data":
1021
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
 
 
1022
 
1023
  elif self.mode == "aggregate_gene_shifts":
 
 
1024
  cos_sims_df = isp_aggregate_gene_shifts(
1025
  cos_sims_df_initial,
1026
  dict_list,
1027
  self.gene_token_id_dict,
1028
  self.gene_id_name_dict,
 
1029
  )
1030
 
1031
  # save perturbation stats to output_path
geneformer/mtl/__init__.py DELETED
@@ -1 +0,0 @@
1
- # ruff: noqa: F401
 
 
geneformer/mtl/collators.py DELETED
@@ -1,76 +0,0 @@
1
- # imports
2
- import torch
3
- import pickle
4
- from ..collator_for_classification import DataCollatorForGeneClassification
5
- from .. import TOKEN_DICTIONARY_FILE
6
-
7
- """Geneformer collator for multi-task cell classification."""
8
-
9
- class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
10
- class_type = "cell"
11
-
12
- @staticmethod
13
- def load_token_dictionary():
14
- with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
15
- return pickle.load(f)
16
-
17
- def __init__(self, *args, **kwargs) -> None:
18
- # Load the token dictionary
19
- token_dictionary = self.load_token_dictionary()
20
- # Use the loaded token dictionary
21
- super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
22
-
23
- def _prepare_batch(self, features):
24
- # Process inputs as usual
25
- batch = self.tokenizer.pad(
26
- features,
27
- class_type=self.class_type,
28
- padding=self.padding,
29
- max_length=self.max_length,
30
- pad_to_multiple_of=self.pad_to_multiple_of,
31
- return_tensors="pt",
32
- )
33
-
34
- # Check if labels are present
35
- if "label" in features[0]:
36
- # Initialize labels dictionary for all tasks
37
- labels = {task: [] for task in features[0]["label"].keys()}
38
- # Populate labels for each task
39
- for feature in features:
40
- for task, label in feature["label"].items():
41
- labels[task].append(label)
42
-
43
- # Convert label lists to tensors, handling dictionaries appropriately
44
- for task in labels:
45
- if isinstance(labels[task][0], (list, torch.Tensor)):
46
- dtype = torch.long
47
- labels[task] = torch.tensor(labels[task], dtype=dtype)
48
- elif isinstance(labels[task][0], dict):
49
- # Handle dict specifically if needed
50
- pass # Resolve nested data structure
51
-
52
- # Update the batch to include task-specific labels
53
- batch["labels"] = labels
54
- else:
55
- # If no labels are present, create empty labels for all tasks
56
- batch["labels"] = {
57
- task: torch.tensor([], dtype=torch.long)
58
- for task in features[0]["input_ids"].keys()
59
- }
60
-
61
- return batch
62
-
63
- def __call__(self, features):
64
- batch = self._prepare_batch(features)
65
- for k, v in batch.items():
66
- if torch.is_tensor(v):
67
- batch[k] = v.clone().detach()
68
- elif isinstance(v, dict):
69
- # Assuming nested structure needs conversion
70
- batch[k] = {
71
- task: torch.tensor(labels, dtype=torch.int64)
72
- for task, labels in v.items()
73
- }
74
- else:
75
- batch[k] = torch.tensor(v, dtype=torch.int64)
76
- return batch
 
 
 
 
geneformer/mtl/data.py DELETED
@@ -1,150 +0,0 @@
1
- import os
2
-
3
- from .collators import DataCollatorForMultitaskCellClassification
4
- from .imports import *
5
-
6
-
7
- def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
8
- try:
9
- dataset = load_from_disk(dataset_path)
10
-
11
- task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
12
- task_to_column = dict(zip(task_names, config["task_columns"]))
13
- config["task_names"] = task_names
14
-
15
- if not is_test:
16
- available_columns = set(dataset.column_names)
17
- for column in task_to_column.values():
18
- if column not in available_columns:
19
- raise KeyError(
20
- f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
21
- )
22
-
23
- label_mappings = {}
24
- task_label_mappings = {}
25
- cell_id_mapping = {}
26
- num_labels_list = []
27
-
28
- # Load or create task label mappings
29
- if not is_test:
30
- for task, column in task_to_column.items():
31
- unique_values = sorted(set(dataset[column])) # Ensure consistency
32
- label_mappings[column] = {
33
- label: idx for idx, label in enumerate(unique_values)
34
- }
35
- task_label_mappings[task] = label_mappings[column]
36
- num_labels_list.append(len(unique_values))
37
-
38
- # Print the mappings for each task with dataset type prefix
39
- for task, mapping in task_label_mappings.items():
40
- print(
41
- f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
42
- ) # sanity check, for train/validation splits
43
-
44
- # Save the task label mappings as a pickle file
45
- with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
46
- pickle.dump(task_label_mappings, f)
47
- else:
48
- # Load task label mappings from pickle file for test data
49
- with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
50
- task_label_mappings = pickle.load(f)
51
-
52
- # Infer num_labels_list from task_label_mappings
53
- for task, mapping in task_label_mappings.items():
54
- num_labels_list.append(len(mapping))
55
-
56
- # Store unique cell IDs in a separate dictionary
57
- for idx, record in enumerate(dataset):
58
- cell_id = record.get("unique_cell_id", idx)
59
- cell_id_mapping[idx] = cell_id
60
-
61
- # Transform records to the desired format
62
- transformed_dataset = []
63
- for idx, record in enumerate(dataset):
64
- transformed_record = {}
65
- transformed_record["input_ids"] = torch.tensor(
66
- record["input_ids"], dtype=torch.long
67
- )
68
-
69
- # Use index-based cell ID for internal tracking
70
- transformed_record["cell_id"] = idx
71
-
72
- if not is_test:
73
- # Prepare labels
74
- label_dict = {}
75
- for task, column in task_to_column.items():
76
- label_value = record[column]
77
- label_index = task_label_mappings[task][label_value]
78
- label_dict[task] = label_index
79
- transformed_record["label"] = label_dict
80
- else:
81
- # Create dummy labels for test data
82
- label_dict = {task: -1 for task in config["task_names"]}
83
- transformed_record["label"] = label_dict
84
-
85
- transformed_dataset.append(transformed_record)
86
-
87
- return transformed_dataset, cell_id_mapping, num_labels_list
88
- except KeyError as e:
89
- print(f"Missing configuration or dataset key: {e}")
90
- except Exception as e:
91
- print(f"An error occurred while loading or preprocessing data: {e}")
92
- return None, None, None
93
-
94
-
95
- def preload_and_process_data(config):
96
- # Load and preprocess data once
97
- train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
98
- config["train_path"], config, dataset_type="train"
99
- )
100
- val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
101
- config["val_path"], config, dataset_type="validation"
102
- )
103
- return (
104
- train_dataset,
105
- train_cell_id_mapping,
106
- val_dataset,
107
- val_cell_id_mapping,
108
- num_labels_list,
109
- )
110
-
111
-
112
- def get_data_loader(preprocessed_dataset, batch_size):
113
- nproc = os.cpu_count() ### I/O operations
114
-
115
- data_collator = DataCollatorForMultitaskCellClassification()
116
-
117
- loader = DataLoader(
118
- preprocessed_dataset,
119
- batch_size=batch_size,
120
- shuffle=True,
121
- collate_fn=data_collator,
122
- num_workers=nproc,
123
- pin_memory=True,
124
- )
125
- return loader
126
-
127
-
128
- def preload_data(config):
129
- # Preprocessing the data before the Optuna trials start
130
- train_loader = get_data_loader("train", config)
131
- val_loader = get_data_loader("val", config)
132
- return train_loader, val_loader
133
-
134
-
135
- def load_and_preprocess_test_data(config):
136
- """
137
- Load and preprocess test data, treating it as unlabeled.
138
- """
139
- return load_and_preprocess_data(config["test_path"], config, is_test=True)
140
-
141
-
142
- def prepare_test_loader(config):
143
- """
144
- Prepare DataLoader for the test dataset.
145
- """
146
- test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
147
- config
148
- )
149
- test_loader = get_data_loader(test_dataset, config["batch_size"])
150
- return test_loader, cell_id_mapping, num_labels_list
 
 
 
 
geneformer/mtl/eval_utils.py DELETED
@@ -1,88 +0,0 @@
1
- import pandas as pd
2
-
3
- from .imports import * # noqa # isort:skip
4
- from .data import prepare_test_loader # noqa # isort:skip
5
- from .model import GeneformerMultiTask
6
-
7
-
8
- def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
9
- task_pred_labels = {task_name: [] for task_name in config["task_names"]}
10
- task_pred_probs = {task_name: [] for task_name in config["task_names"]}
11
- cell_ids = []
12
-
13
- # # Load task label mappings from pickle file
14
- # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
15
- # task_label_mappings = pickle.load(f)
16
-
17
- model.eval()
18
- with torch.no_grad():
19
- for batch in test_loader:
20
- input_ids = batch["input_ids"].to(device)
21
- attention_mask = batch["attention_mask"].to(device)
22
- _, logits, _ = model(input_ids, attention_mask)
23
- for sample_idx in range(len(batch["input_ids"])):
24
- cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
25
- cell_ids.append(cell_id)
26
- for i, task_name in enumerate(config["task_names"]):
27
- pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
28
- pred_prob = (
29
- torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
30
- )
31
- task_pred_labels[task_name].append(pred_label)
32
- task_pred_probs[task_name].append(pred_prob)
33
-
34
- # Save test predictions with cell IDs and probabilities to CSV
35
- test_results_dir = config["results_dir"]
36
- os.makedirs(test_results_dir, exist_ok=True)
37
- test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
38
-
39
- rows = []
40
- for sample_idx in range(len(cell_ids)):
41
- row = {"Cell ID": cell_ids[sample_idx]}
42
- for task_name in config["task_names"]:
43
- row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
44
- row[f"{task_name} Probabilities"] = ",".join(
45
- map(str, task_pred_probs[task_name][sample_idx])
46
- )
47
- rows.append(row)
48
-
49
- df = pd.DataFrame(rows)
50
- df.to_csv(test_preds_file, index=False)
51
- print(f"Test predictions saved to {test_preds_file}")
52
-
53
-
54
- def load_and_evaluate_test_model(config):
55
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
- test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
57
- model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
58
- hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
59
-
60
- # Load the saved best hyperparameters
61
- with open(hyperparams_path, "r") as f:
62
- best_hyperparams = json.load(f)
63
-
64
- # Extract the task weights if present, otherwise set to None
65
- task_weights = best_hyperparams.get("task_weights", None)
66
- normalized_task_weights = task_weights if task_weights else []
67
-
68
- # Print the loaded hyperparameters
69
- print("Loaded hyperparameters:")
70
- for param, value in best_hyperparams.items():
71
- if param == "task_weights":
72
- print(f"normalized_task_weights: {value}")
73
- else:
74
- print(f"{param}: {value}")
75
-
76
- best_model_path = os.path.join(model_directory, "pytorch_model.bin")
77
- best_model = GeneformerMultiTask(
78
- config["pretrained_path"],
79
- num_labels_list,
80
- dropout_rate=best_hyperparams["dropout_rate"],
81
- use_task_weights=config["use_task_weights"],
82
- task_weights=normalized_task_weights,
83
- )
84
- best_model.load_state_dict(torch.load(best_model_path))
85
- best_model.to(device)
86
-
87
- evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
88
- print("Evaluation completed.")
 
 
 
 
 
 
geneformer/mtl/imports.py DELETED
@@ -1,43 +0,0 @@
1
- import functools
2
- import gc
3
- import json
4
- import os
5
- import pickle
6
- import sys
7
- import warnings
8
- from enum import Enum
9
- from itertools import chain
10
- from typing import Dict, List, Optional, Union
11
-
12
- import numpy as np
13
- import optuna
14
- import pandas as pd
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import torch.optim as optim
19
- from datasets import load_from_disk
20
- from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
21
- from sklearn.model_selection import train_test_split
22
- from sklearn.preprocessing import LabelEncoder
23
- from torch.utils.data import DataLoader
24
- from transformers import (
25
- AdamW,
26
- BatchEncoding,
27
- BertConfig,
28
- BertModel,
29
- DataCollatorForTokenClassification,
30
- SpecialTokensMixin,
31
- get_cosine_schedule_with_warmup,
32
- get_linear_schedule_with_warmup,
33
- get_scheduler,
34
- )
35
- from transformers.utils import logging, to_py_obj
36
-
37
- from .collators import DataCollatorForMultitaskCellClassification
38
-
39
- # local modules
40
- from .data import get_data_loader, preload_and_process_data
41
- from .model import GeneformerMultiTask
42
- from .optuna_utils import create_optuna_study
43
- from .utils import save_model
 
 
 
 
geneformer/mtl/model.py DELETED
@@ -1,121 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import BertConfig, BertModel
4
-
5
-
6
- class AttentionPool(nn.Module):
7
- """Attention-based pooling layer."""
8
-
9
- def __init__(self, hidden_size):
10
- super(AttentionPool, self).__init__()
11
- self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
12
- nn.init.xavier_uniform_(
13
- self.attention_weights
14
- ) # https://pytorch.org/docs/stable/nn.init.html
15
-
16
- def forward(self, hidden_states):
17
- attention_scores = torch.matmul(hidden_states, self.attention_weights)
18
- attention_scores = torch.softmax(attention_scores, dim=1)
19
- pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
20
- return pooled_output
21
-
22
-
23
- class GeneformerMultiTask(nn.Module):
24
- def __init__(
25
- self,
26
- pretrained_path,
27
- num_labels_list,
28
- dropout_rate=0.1,
29
- use_task_weights=False,
30
- task_weights=None,
31
- max_layers_to_freeze=0,
32
- use_attention_pooling=False,
33
- ):
34
- super(GeneformerMultiTask, self).__init__()
35
- self.config = BertConfig.from_pretrained(pretrained_path)
36
- self.bert = BertModel(self.config)
37
- self.num_labels_list = num_labels_list
38
- self.use_task_weights = use_task_weights
39
- self.dropout = nn.Dropout(dropout_rate)
40
- self.use_attention_pooling = use_attention_pooling
41
-
42
- if use_task_weights and (
43
- task_weights is None or len(task_weights) != len(num_labels_list)
44
- ):
45
- raise ValueError(
46
- "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
47
- )
48
- self.task_weights = (
49
- task_weights if use_task_weights else [1.0] * len(num_labels_list)
50
- )
51
-
52
- # Freeze the specified initial layers
53
- for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
54
- for param in layer.parameters():
55
- param.requires_grad = False
56
-
57
- self.attention_pool = (
58
- AttentionPool(self.config.hidden_size) if use_attention_pooling else None
59
- )
60
-
61
- self.classification_heads = nn.ModuleList(
62
- [
63
- nn.Linear(self.config.hidden_size, num_labels)
64
- for num_labels in num_labels_list
65
- ]
66
- )
67
- # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
68
- for head in self.classification_heads:
69
- nn.init.xavier_uniform_(head.weight)
70
- nn.init.zeros_(head.bias)
71
-
72
- def forward(self, input_ids, attention_mask, labels=None):
73
- try:
74
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
75
- except Exception as e:
76
- raise RuntimeError(f"Error during BERT forward pass: {e}")
77
-
78
- sequence_output = outputs.last_hidden_state
79
-
80
- try:
81
- pooled_output = (
82
- self.attention_pool(sequence_output)
83
- if self.use_attention_pooling
84
- else sequence_output[:, 0, :]
85
- )
86
- pooled_output = self.dropout(pooled_output)
87
- except Exception as e:
88
- raise RuntimeError(f"Error during pooling and dropout: {e}")
89
-
90
- total_loss = 0
91
- logits = []
92
- losses = []
93
-
94
- for task_id, (head, num_labels) in enumerate(
95
- zip(self.classification_heads, self.num_labels_list)
96
- ):
97
- try:
98
- task_logits = head(pooled_output)
99
- except Exception as e:
100
- raise RuntimeError(
101
- f"Error during forward pass of classification head {task_id}: {e}"
102
- )
103
-
104
- logits.append(task_logits)
105
-
106
- if labels is not None:
107
- try:
108
- loss_fct = nn.CrossEntropyLoss()
109
- task_loss = loss_fct(
110
- task_logits.view(-1, num_labels), labels[task_id].view(-1)
111
- )
112
- if self.use_task_weights:
113
- task_loss *= self.task_weights[task_id]
114
- total_loss += task_loss
115
- losses.append(task_loss.item())
116
- except Exception as e:
117
- raise RuntimeError(
118
- f"Error during loss computation for task {task_id}: {e}"
119
- )
120
-
121
- return total_loss, logits, losses if labels is not None else logits
 
 
 
 
geneformer/mtl/optuna_utils.py DELETED
@@ -1,27 +0,0 @@
1
- import optuna
2
- from optuna.integration import TensorBoardCallback
3
-
4
-
5
- def save_trial_callback(study, trial, trials_result_path):
6
- with open(trials_result_path, "a") as f:
7
- f.write(
8
- f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
9
- )
10
-
11
-
12
- def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
13
- study = optuna.create_study(direction="maximize")
14
-
15
- # init TensorBoard callback
16
- tensorboard_callback = TensorBoardCallback(
17
- dirname=tensorboard_log_dir, metric_name="F1 Macro"
18
- )
19
-
20
- # callback and TensorBoard callback
21
- callbacks = [
22
- lambda study, trial: save_trial_callback(study, trial, trials_result_path),
23
- tensorboard_callback,
24
- ]
25
-
26
- study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
27
- return study
 
 
 
 
geneformer/mtl/train.py DELETED
@@ -1,380 +0,0 @@
1
- import os
2
- import random
3
-
4
- import numpy as np
5
- import pandas as pd
6
- import torch
7
- from torch.utils.tensorboard import SummaryWriter
8
- from tqdm import tqdm
9
-
10
- from .imports import *
11
- from .model import GeneformerMultiTask
12
- from .utils import calculate_task_specific_metrics, get_layer_freeze_range
13
-
14
-
15
- def set_seed(seed):
16
- random.seed(seed)
17
- np.random.seed(seed)
18
- torch.manual_seed(seed)
19
- torch.cuda.manual_seed_all(seed)
20
- torch.backends.cudnn.deterministic = True
21
- torch.backends.cudnn.benchmark = False
22
-
23
-
24
- def initialize_wandb(config):
25
- if config.get("use_wandb", False):
26
- import wandb
27
-
28
- wandb.init(project=config["wandb_project"], config=config)
29
- print("Weights & Biases (wandb) initialized and will be used for logging.")
30
- else:
31
- print(
32
- "Weights & Biases (wandb) is not enabled. Logging will use other methods."
33
- )
34
-
35
-
36
- def create_model(config, num_labels_list, device):
37
- model = GeneformerMultiTask(
38
- config["pretrained_path"],
39
- num_labels_list,
40
- dropout_rate=config["dropout_rate"],
41
- use_task_weights=config["use_task_weights"],
42
- task_weights=config["task_weights"],
43
- max_layers_to_freeze=config["max_layers_to_freeze"],
44
- use_attention_pooling=config["use_attention_pooling"],
45
- )
46
- if config["use_data_parallel"]:
47
- model = nn.DataParallel(model)
48
- return model.to(device)
49
-
50
-
51
- def setup_optimizer_and_scheduler(model, config, total_steps):
52
- optimizer = AdamW(
53
- model.parameters(),
54
- lr=config["learning_rate"],
55
- weight_decay=config["weight_decay"],
56
- )
57
- warmup_steps = int(config["warmup_ratio"] * total_steps)
58
-
59
- if config["lr_scheduler_type"] == "linear":
60
- scheduler = get_linear_schedule_with_warmup(
61
- optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
62
- )
63
- elif config["lr_scheduler_type"] == "cosine":
64
- scheduler = get_cosine_schedule_with_warmup(
65
- optimizer,
66
- num_warmup_steps=warmup_steps,
67
- num_training_steps=total_steps,
68
- num_cycles=0.5,
69
- )
70
-
71
- return optimizer, scheduler
72
-
73
-
74
- def train_epoch(
75
- model, train_loader, optimizer, scheduler, device, config, writer, epoch
76
- ):
77
- model.train()
78
- progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
79
- for batch_idx, batch in enumerate(progress_bar):
80
- optimizer.zero_grad()
81
- input_ids = batch["input_ids"].to(device)
82
- attention_mask = batch["attention_mask"].to(device)
83
- labels = [
84
- batch["labels"][task_name].to(device) for task_name in config["task_names"]
85
- ]
86
-
87
- loss, _, _ = model(input_ids, attention_mask, labels)
88
- loss.backward()
89
-
90
-        if config["gradient_clipping"]:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
-
-        optimizer.step()
-        scheduler.step()
-
-        writer.add_scalar(
-            "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
-        )
-        if config.get("use_wandb", False):
-            import wandb
-
-            wandb.log({"Training Loss": loss.item()})
-
-        # Update progress bar
-        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
-
-    return loss.item()  # Return the last batch loss
-
-
-def validate_model(model, val_loader, device, config):
-    model.eval()
-    val_loss = 0.0
-    task_true_labels = {task_name: [] for task_name in config["task_names"]}
-    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
-    task_pred_probs = {task_name: [] for task_name in config["task_names"]}
-
-    with torch.no_grad():
-        for batch in val_loader:
-            input_ids = batch["input_ids"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
-            labels = [
-                batch["labels"][task_name].to(device)
-                for task_name in config["task_names"]
-            ]
-            loss, logits, _ = model(input_ids, attention_mask, labels)
-            val_loss += loss.item()
-
-            for sample_idx in range(len(batch["input_ids"])):
-                for i, task_name in enumerate(config["task_names"]):
-                    true_label = batch["labels"][task_name][sample_idx].item()
-                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
-                    pred_prob = (
-                        torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
-                    )
-                    task_true_labels[task_name].append(true_label)
-                    task_pred_labels[task_name].append(pred_label)
-                    task_pred_probs[task_name].append(pred_prob)
-
-    val_loss /= len(val_loader)
-    return val_loss, task_true_labels, task_pred_labels, task_pred_probs
-
-
-def log_metrics(task_metrics, val_loss, config, writer, epochs):
-    for task_name, metrics in task_metrics.items():
-        print(
-            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
-        )
-        if config.get("use_wandb", False):
-            import wandb
-
-            wandb.log(
-                {
-                    f"{task_name} Validation F1 Macro": metrics["f1"],
-                    f"{task_name} Validation Accuracy": metrics["accuracy"],
-                }
-            )
-
-    writer.add_scalar("Validation Loss", val_loss, epochs)
-    for task_name, metrics in task_metrics.items():
-        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
-        writer.add_scalar(
-            f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
-        )
-
-
-def save_validation_predictions(
-    val_cell_id_mapping,
-    task_true_labels,
-    task_pred_labels,
-    task_pred_probs,
-    config,
-    trial_number=None,
-):
-    if trial_number is not None:
-        trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
-        os.makedirs(trial_results_dir, exist_ok=True)
-        val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
-    else:
-        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
-
-    rows = []
-    for sample_idx in range(len(val_cell_id_mapping)):
-        row = {"Cell ID": val_cell_id_mapping[sample_idx]}
-        for task_name in config["task_names"]:
-            row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
-            row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
-            row[f"{task_name} Probabilities"] = ",".join(
-                map(str, task_pred_probs[task_name][sample_idx])
-            )
-        rows.append(row)
-
-    df = pd.DataFrame(rows)
-    df.to_csv(val_preds_file, index=False)
-    print(f"Validation predictions saved to {val_preds_file}")
-
-
-def train_model(
-    config,
-    device,
-    train_loader,
-    val_loader,
-    train_cell_id_mapping,
-    val_cell_id_mapping,
-    num_labels_list,
-):
-    set_seed(config["seed"])
-    initialize_wandb(config)
-
-    model = create_model(config, num_labels_list, device)
-    total_steps = len(train_loader) * config["epochs"]
-    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
-
-    log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
-    writer = SummaryWriter(log_dir=log_dir)
-
-    epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
-    for epoch in epoch_progress:
-        last_loss = train_epoch(
-            model, train_loader, optimizer, scheduler, device, config, writer, epoch
-        )
-        epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
-
-    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
-        model, val_loader, device, config
-    )
-    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
-
-    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
-    writer.close()
-
-    save_validation_predictions(
-        val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
-    )
-
-    if config.get("use_wandb", False):
-        import wandb
-
-        wandb.finish()
-
-    print(f"\nFinal Validation Loss: {val_loss:.4f}")
-    return val_loss, model  # Return both the validation loss and the trained model
-
-
-def objective(
-    trial,
-    train_loader,
-    val_loader,
-    train_cell_id_mapping,
-    val_cell_id_mapping,
-    num_labels_list,
-    config,
-    device,
-):
-    set_seed(config["seed"])  # Set the seed before each trial
-    initialize_wandb(config)
-
-    # Hyperparameters
-    config["learning_rate"] = trial.suggest_float(
-        "learning_rate",
-        config["hyperparameters"]["learning_rate"]["low"],
-        config["hyperparameters"]["learning_rate"]["high"],
-        log=config["hyperparameters"]["learning_rate"]["log"],
-    )
-    config["warmup_ratio"] = trial.suggest_float(
-        "warmup_ratio",
-        config["hyperparameters"]["warmup_ratio"]["low"],
-        config["hyperparameters"]["warmup_ratio"]["high"],
-    )
-    config["weight_decay"] = trial.suggest_float(
-        "weight_decay",
-        config["hyperparameters"]["weight_decay"]["low"],
-        config["hyperparameters"]["weight_decay"]["high"],
-    )
-    config["dropout_rate"] = trial.suggest_float(
-        "dropout_rate",
-        config["hyperparameters"]["dropout_rate"]["low"],
-        config["hyperparameters"]["dropout_rate"]["high"],
-    )
-    config["lr_scheduler_type"] = trial.suggest_categorical(
-        "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
-    )
-    config["use_attention_pooling"] = trial.suggest_categorical(
-        "use_attention_pooling", [False]
-    )
-
-    if config["use_task_weights"]:
-        config["task_weights"] = [
-            trial.suggest_float(
-                f"task_weight_{i}",
-                config["hyperparameters"]["task_weights"]["low"],
-                config["hyperparameters"]["task_weights"]["high"],
-            )
-            for i in range(len(num_labels_list))
-        ]
-        weight_sum = sum(config["task_weights"])
-        config["task_weights"] = [
-            weight / weight_sum for weight in config["task_weights"]
-        ]
-    else:
-        config["task_weights"] = None
-
-    # Dynamic range for max_layers_to_freeze
-    freeze_range = get_layer_freeze_range(config["pretrained_path"])
-    config["max_layers_to_freeze"] = trial.suggest_int(
-        "max_layers_to_freeze",
-        freeze_range["min"],
-        freeze_range["max"]
-    )
-
-    model = create_model(config, num_labels_list, device)
-    total_steps = len(train_loader) * config["epochs"]
-    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
-
-    log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
-    writer = SummaryWriter(log_dir=log_dir)
-
-    for epoch in range(config["epochs"]):
-        train_epoch(
-            model, train_loader, optimizer, scheduler, device, config, writer, epoch
-        )
-
-    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
-        model, val_loader, device, config
-    )
-    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
-
-    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
-    writer.close()
-
-    save_validation_predictions(
-        val_cell_id_mapping,
-        task_true_labels,
-        task_pred_labels,
-        task_pred_probs,
-        config,
-        trial.number,
-    )
-
-    trial.set_user_attr("model_state_dict", model.state_dict())
-    trial.set_user_attr("task_weights", config["task_weights"])
-
-    trial.report(val_loss, config["epochs"])
-
-    if trial.should_prune():
-        raise optuna.TrialPruned()
-
-    if config.get("use_wandb", False):
-        import wandb
-
-        wandb.log(
-            {
-                "trial_number": trial.number,
-                "val_loss": val_loss,
-                **{
-                    f"{task_name}_f1": metrics["f1"]
-                    for task_name, metrics in task_metrics.items()
-                },
-                **{
-                    f"{task_name}_accuracy": metrics["accuracy"]
-                    for task_name, metrics in task_metrics.items()
-                },
-                **{
-                    k: v
-                    for k, v in config.items()
-                    if k
-                    in [
-                        "learning_rate",
-                        "warmup_ratio",
-                        "weight_decay",
-                        "dropout_rate",
-                        "lr_scheduler_type",
-                        "use_attention_pooling",
-                        "max_layers_to_freeze",
-                    ]
-                },
-            }
-        )
-        wandb.finish()
-
-    return val_loss
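The deleted objective() above draws every tuned value from a nested config["hyperparameters"] block, so the shape of that block is fixed by the suggest_* calls in the diff even though its contents are user supplied. A minimal sketch of the expected structure follows; all bounds and choices are illustrative placeholders, not values shipped with Geneformer.

# Sketch of the search-space block read by the deleted objective(); bounds are placeholders.
config = {
    # ... other run settings ...
    "hyperparameters": {
        "learning_rate": {"low": 1e-5, "high": 1e-3, "log": True},
        "warmup_ratio": {"low": 0.005, "high": 0.1},
        "weight_decay": {"low": 0.0, "high": 0.1},
        "dropout_rate": {"low": 0.0, "high": 0.5},
        "lr_scheduler_type": {"choices": ["linear", "cosine"]},
        "task_weights": {"low": 0.1, "high": 2.0},  # read only when use_task_weights is True
    },
}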
geneformer/mtl/train_utils.py DELETED
@@ -1,161 +0,0 @@
-import random
-
-from .data import get_data_loader, preload_and_process_data
-from .imports import *
-from .model import GeneformerMultiTask
-from .train import objective, train_model
-from .utils import save_model
-
-
-def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
-
-def run_manual_tuning(config):
-    # Set seed for reproducibility
-    set_seed(config["seed"])
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    (
-        train_dataset,
-        train_cell_id_mapping,
-        val_dataset,
-        val_cell_id_mapping,
-        num_labels_list,
-    ) = preload_and_process_data(config)
-    train_loader = get_data_loader(train_dataset, config["batch_size"])
-    val_loader = get_data_loader(val_dataset, config["batch_size"])
-
-    # Print the manual hyperparameters being used
-    print("\nManual hyperparameters being used:")
-    for key, value in config["manual_hyperparameters"].items():
-        print(f"{key}: {value}")
-    print()  # Add an empty line for better readability
-
-    # Use the manual hyperparameters
-    for key, value in config["manual_hyperparameters"].items():
-        config[key] = value
-
-    # Train the model
-    val_loss, trained_model = train_model(
-        config,
-        device,
-        train_loader,
-        val_loader,
-        train_cell_id_mapping,
-        val_cell_id_mapping,
-        num_labels_list,
-    )
-
-    print(f"\nValidation loss with manual hyperparameters: {val_loss}")
-
-    # Save the trained model
-    model_save_directory = os.path.join(
-        config["model_save_path"], "GeneformerMultiTask"
-    )
-    save_model(trained_model, model_save_directory)
-
-    # Save the hyperparameters
-    hyperparams_to_save = {
-        **config["manual_hyperparameters"],
-        "dropout_rate": config["dropout_rate"],
-        "use_task_weights": config["use_task_weights"],
-        "task_weights": config["task_weights"],
-        "max_layers_to_freeze": config["max_layers_to_freeze"],
-        "use_attention_pooling": config["use_attention_pooling"],
-    }
-    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
-    with open(hyperparams_path, "w") as f:
-        json.dump(hyperparams_to_save, f)
-    print(f"Manual hyperparameters saved to {hyperparams_path}")
-
-    return val_loss
-
-
-def run_optuna_study(config):
-    # Set seed for reproducibility
-    set_seed(config["seed"])
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    (
-        train_dataset,
-        train_cell_id_mapping,
-        val_dataset,
-        val_cell_id_mapping,
-        num_labels_list,
-    ) = preload_and_process_data(config)
-    train_loader = get_data_loader(train_dataset, config["batch_size"])
-    val_loader = get_data_loader(val_dataset, config["batch_size"])
-
-    if config["use_manual_hyperparameters"]:
-        train_model(
-            config,
-            device,
-            train_loader,
-            val_loader,
-            train_cell_id_mapping,
-            val_cell_id_mapping,
-            num_labels_list,
-        )
-    else:
-        objective_with_config_and_data = functools.partial(
-            objective,
-            train_loader=train_loader,
-            val_loader=val_loader,
-            train_cell_id_mapping=train_cell_id_mapping,
-            val_cell_id_mapping=val_cell_id_mapping,
-            num_labels_list=num_labels_list,
-            config=config,
-            device=device,
-        )
-
-        study = optuna.create_study(
-            direction="minimize",  # Minimize validation loss
-            study_name=config["study_name"],
-            # storage=config["storage"],
-            load_if_exists=True,
-        )
-
-        study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
-
-        # After finding the best trial
-        best_params = study.best_trial.params
-        best_task_weights = study.best_trial.user_attrs["task_weights"]
-        print("Saving the best model and its hyperparameters...")
-
-        # Saving model as before
-        best_model = GeneformerMultiTask(
-            config["pretrained_path"],
-            num_labels_list,
-            dropout_rate=best_params["dropout_rate"],
-            use_task_weights=config["use_task_weights"],
-            task_weights=best_task_weights,
-        )
-
-        # Get the best model state dictionary
-        best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
-
-        # Remove the "module." prefix from the state dictionary keys if present
-        best_model_state_dict = {
-            k.replace("module.", ""): v for k, v in best_model_state_dict.items()
-        }
-
-        # Load the modified state dictionary into the model, skipping unexpected keys
-        best_model.load_state_dict(best_model_state_dict, strict=False)
-
-        model_save_directory = os.path.join(
-            config["model_save_path"], "GeneformerMultiTask"
-        )
-        save_model(best_model, model_save_directory)
-
-        # Additionally, save the best hyperparameters and task weights
-        hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
-
-        with open(hyperparams_path, "w") as f:
-            json.dump({**best_params, "task_weights": best_task_weights}, f)
-        print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
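run_manual_tuning and run_optuna_study were the two entry points of this module, and both read a flat config dictionary whose keys can be inferred from the code above. The call below is a hypothetical sketch for versions that still ship this module: it omits the dataset-related keys consumed by preload_and_process_data, and every path and value is an illustrative placeholder.

# Hypothetical driver; paths and values are placeholders, dataset-related keys omitted.
from geneformer.mtl.train_utils import run_manual_tuning

config = {
    "seed": 42,
    "batch_size": 8,
    "epochs": 1,
    "task_names": ["cell_type"],
    "pretrained_path": "/path/to/pretrained_model",
    "model_save_path": "/path/to/output/models",
    "results_dir": "/path/to/output/results",
    "tensorboard_log_dir": "/path/to/output/tensorboard",
    "use_wandb": False,
    "use_task_weights": False,
    "task_weights": None,
    "gradient_clipping": False,
    "manual_hyperparameters": {
        "learning_rate": 5e-5,
        "warmup_ratio": 0.01,
        "weight_decay": 0.01,
        "dropout_rate": 0.1,
        "lr_scheduler_type": "linear",
        "use_attention_pooling": False,
        "max_layers_to_freeze": 2,
    },
}

val_loss = run_manual_tuning(config)
# run_optuna_study(config) additionally expects use_manual_hyperparameters,
# study_name, n_trials, and the "hyperparameters" search-space block.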
geneformer/mtl/utils.py DELETED
@@ -1,129 +0,0 @@
-import os
-import shutil
-
-from sklearn.metrics import accuracy_score, f1_score
-from sklearn.preprocessing import LabelEncoder
-from transformers import AutoConfig, BertConfig, BertModel
-
-from .imports import *
-
-
-def save_model(model, model_save_directory):
-    if not os.path.exists(model_save_directory):
-        os.makedirs(model_save_directory)
-
-    # Get the state dict
-    if isinstance(model, nn.DataParallel):
-        model_state_dict = (
-            model.module.state_dict()
-        )  # Use model.module to access the underlying model
-    else:
-        model_state_dict = model.state_dict()
-
-    # Remove the "module." prefix from the keys if present
-    model_state_dict = {
-        k.replace("module.", ""): v for k, v in model_state_dict.items()
-    }
-
-    model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
-    torch.save(model_state_dict, model_save_path)
-
-    # Save the model configuration
-    if isinstance(model, nn.DataParallel):
-        model.module.config.to_json_file(
-            os.path.join(model_save_directory, "config.json")
-        )
-    else:
-        model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
-
-    print(f"Model and configuration saved to {model_save_directory}")
-
-
-def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
-    task_metrics = {}
-    for task_name in task_true_labels.keys():
-        true_labels = task_true_labels[task_name]
-        pred_labels = task_pred_labels[task_name]
-        f1 = f1_score(true_labels, pred_labels, average="macro")
-        accuracy = accuracy_score(true_labels, pred_labels)
-        task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
-    return task_metrics
-
-
-def calculate_combined_f1(combined_labels, combined_preds):
-    # Initialize the LabelEncoder
-    le = LabelEncoder()
-
-    # Fit and transform combined labels and predictions to numerical values
-    le.fit(combined_labels + combined_preds)
-    encoded_true_labels = le.transform(combined_labels)
-    encoded_pred_labels = le.transform(combined_preds)
-
-    # Print out the mapping for sanity check
-    print("\nLabel Encoder Mapping:")
-    for index, class_label in enumerate(le.classes_):
-        print(f"'{class_label}': {index}")
-
-    # Calculate accuracy
-    accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
-
-    # Calculate F1 Macro score
-    f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
-
-    return f1, accuracy
-
-
-# def save_model_without_heads(original_model_save_directory):
-#     # Create a new directory for the model without heads
-#     new_model_save_directory = original_model_save_directory + "_No_Heads"
-#     if not os.path.exists(new_model_save_directory):
-#         os.makedirs(new_model_save_directory)
-
-#     # Load the model state dictionary
-#     model_state_dict = torch.load(
-#         os.path.join(original_model_save_directory, "pytorch_model.bin")
-#     )
-
-#     # Initialize a new BERT model without the classification heads
-#     config = BertConfig.from_pretrained(
-#         os.path.join(original_model_save_directory, "config.json")
-#     )
-#     model_without_heads = BertModel(config)
-
-#     # Filter the state dict to exclude classification heads
-#     model_without_heads_state_dict = {
-#         k: v
-#         for k, v in model_state_dict.items()
-#         if not k.startswith("classification_heads")
-#     }
-
-#     # Load the filtered state dict into the model
-#     model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
-
-#     # Save the model without heads
-#     model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
-#     torch.save(model_without_heads.state_dict(), model_save_path)
-
-#     # Copy the configuration file
-#     shutil.copy(
-#         os.path.join(original_model_save_directory, "config.json"),
-#         new_model_save_directory,
-#     )
-
-#     print(f"Model without classification heads saved to {new_model_save_directory}")
-
-
-def get_layer_freeze_range(pretrained_path):
-    """
-    Dynamically determines the number of layers to freeze based on the model depth from its configuration.
-    Args:
-        pretrained_path (str): Path to the pretrained model directory or model identifier.
-    Returns:
-        dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
-    """
-    if pretrained_path:
-        config = AutoConfig.from_pretrained(pretrained_path)
-        total_layers = config.num_hidden_layers
-        return {"min": 0, "max": total_layers - 1}
-    else:
-        return {"min": 0, "max": 0}
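get_layer_freeze_range keys the freeze search space off num_hidden_layers in the checkpoint's config.json, so a 6-layer checkpoint yields {"min": 0, "max": 5}, while calculate_combined_f1 label-encodes pooled string labels before scoring. A short illustrative use of both helpers follows; the checkpoint path and label strings are made up for the example.

# Illustrative only: the checkpoint path and label strings are placeholders.
from geneformer.mtl.utils import calculate_combined_f1, get_layer_freeze_range

freeze_range = get_layer_freeze_range("/path/to/6_layer_checkpoint")
print(freeze_range)  # {'min': 0, 'max': 5} for a model with num_hidden_layers == 6

# Pooled true labels and predictions across tasks, scored with a shared LabelEncoder
f1_macro, accuracy = calculate_combined_f1(
    ["healthy", "dcm", "healthy", "hcm"],
    ["healthy", "dcm", "hcm", "hcm"],
)
print(f"combined F1 macro: {f1_macro:.3f}, accuracy: {accuracy:.3f}")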