commit from jbz 2023.08.29 for the first time
- MANIFEST.in +3 -0
- README.md +67 -0
- config.json +24 -0
- examples/cell_classification.ipynb +1952 -0
- examples/extract_and_plot_cell_embeddings.ipynb +0 -0
- examples/gene_classification.ipynb +0 -0
- examples/hyperparam_optimiz_for_disease_classifier.py +226 -0
- examples/in_silico_perturbation.ipynb +110 -0
- examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb +365 -0
- examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +165 -0
- examples/tokenizing_scRNAseq_data.ipynb +72 -0
- fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/config.json +35 -0
- fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/trainer_state.json +150 -0
- geneformer-12L-30M/config.json +23 -0
- geneformer/__init__.py +12 -0
- geneformer/collator_for_classification.py +602 -0
- geneformer/emb_extractor.py +493 -0
- geneformer/gene_median_dictionary.pkl +3 -0
- geneformer/gene_name_id_dict.pkl +3 -0
- geneformer/in_silico_perturber.py +1297 -0
- geneformer/in_silico_perturber_stats.py +716 -0
- geneformer/pretrainer.py +822 -0
- geneformer/token_dictionary.pkl +3 -0
- geneformer/tokenizer.py +235 -0
- generation_config.json +5 -0
- pytorch_model.bin +3 -0
- setup.py +21 -0
- training_args.bin +3 -0
MANIFEST.in
ADDED
@@ -0,0 +1,3 @@
include geneformer/gene_median_dictionary.pkl
include geneformer/token_dictionary.pkl
include geneformer/gene_name_id_dict.pkl
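
These MANIFEST entries ship the pickled dictionaries with the installed package. As a minimal sketch (assuming the standard layout of the installed geneformer module; the comment about what the dictionary maps is an assumption based on the tokenizer in this repository), they can be located and loaded like this:

```python
# Minimal sketch: load a packaged dictionary relative to the installed module.
# Assumption: token_dictionary.pkl maps gene identifiers to input token ids.
import pickle
from pathlib import Path

import geneformer

pkg_dir = Path(geneformer.__file__).parent
with open(pkg_dir / "token_dictionary.pkl", "rb") as f:
    token_dictionary = pickle.load(f)

print(len(token_dictionary))  # number of entries in the dictionary
```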
README.md
ADDED
@@ -0,0 +1,67 @@
---
datasets: ctheodoris/Genecorpus-30M
license: apache-2.0
---
# Geneformer
Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single-cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.

See [our manuscript](https://rdcu.be/ddrx0) for details.

# Model Description
Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single-cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell's transcriptome is presented to the model as a rank value encoding, where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell's transcriptome and takes advantage of the many observations of each gene's expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method deprioritizes ubiquitously highly expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that systematically bias absolute transcript counts, while the overall relative ranking of genes within each cell remains more stable.
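
To make the rank value encoding concrete, here is a minimal sketch assuming a per-cell expression vector and the corpus-wide nonzero-median dictionary shipped as geneformer/gene_median_dictionary.pkl; names and toy values are illustrative, and the authoritative implementation is geneformer/tokenizer.py.

```python
# Minimal sketch of the rank value encoding (illustrative, not the repo's tokenizer).
def rank_value_encode(expression: dict, gene_median_dict: dict) -> list:
    """Order genes from most to least cell-state-distinguishing."""
    normalized = {
        gene: value / gene_median_dict[gene]
        for gene, value in expression.items()
        if value > 0 and gene in gene_median_dict
    }
    # Housekeeping genes with large corpus-wide medians are pushed to lower ranks;
    # lowly expressed but distinctive genes (e.g. transcription factors) move up.
    return sorted(normalized, key=normalized.get, reverse=True)

# toy example
cell = {"GAPDH": 500.0, "NKX2-5": 3.0, "TTN": 20.0}
medians = {"GAPDH": 480.0, "NKX2-5": 1.0, "TTN": 15.0}
print(rank_value_encode(cell, medians))  # ['NKX2-5', 'TTN', 'GAPDH']
```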

The rank value encoding of each single cell's transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective in which 15% of the genes within each transcriptome were masked, and the model was trained to predict which gene belongs in each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
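
As a rough illustration of that masked objective (not the repository's pretrainer.py, which handles this for real and may differ in detail), masking 15% of a tokenized cell while keeping the originals as labels could look like the following; MASK_TOKEN_ID is a placeholder for the id defined in token_dictionary.pkl.

```python
# Rough illustration of masked-gene pretraining targets for one tokenized cell.
import random

MASK_TOKEN_ID = 1  # placeholder; the real id comes from token_dictionary.pkl

def mask_cell(input_ids, mask_prob=0.15, seed=0):
    rng = random.Random(seed)
    labels = [-100] * len(input_ids)   # -100 positions are ignored by the loss
    masked = list(input_ids)
    for i, token in enumerate(input_ids):
        if rng.random() < mask_prob:
            labels[i] = token          # the model must predict the original gene here
            masked[i] = MASK_TOKEN_ID
    return masked, labels

masked, labels = mask_cell([10, 42, 7, 99, 23, 5])
print(masked)
print(labels)
```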

We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).

During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model's attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents a pretrained deep learning model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.

In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6-layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12-layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.

# Application
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.

Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:

*Fine-tuning*:
- transcription factor dosage sensitivity
- chromatin dynamics (bivalently marked promoters)
- transcription factor regulatory range
- gene network centrality
- transcription factor targets
- cell type annotation
- batch integration
- cell state classification across differentiation
- disease classification
- in silico perturbation to determine disease-driving genes
- in silico treatment to determine candidate therapeutic targets

*Zero-shot learning*:
- batch integration
- gene context specificity
- in silico reprogramming
- in silico differentiation
- in silico perturbation to determine impact on cell state
- in silico perturbation to determine transcription factor targets
- in silico perturbation to determine transcription factor cooperativity

# Installation
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single-cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico perturbation with either the pretrained or fine-tuned models. To install:

```bash
git clone https://huggingface.co/ctheodoris/Geneformer
cd Geneformer
pip install .
```
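
Once installed, the pretrained checkpoint can be loaded either with its masked-LM head (for zero-shot uses such as in silico perturbation) or with a classification head for fine-tuning. A minimal loading sketch with Hugging Face transformers; the number of labels is an illustrative placeholder, and a local clone path can replace the hub id.

```python
# Minimal loading sketch; num_labels is a placeholder for your downstream task.
from transformers import BertForMaskedLM, BertForSequenceClassification

# keep the masked-LM head for zero-shot / in silico perturbation style use
mlm_model = BertForMaskedLM.from_pretrained("ctheodoris/Geneformer")

# swap in a sequence classification head for fine-tuning (e.g. cell state classification)
clf_model = BertForSequenceClassification.from_pretrained(
    "ctheodoris/Geneformer",
    num_labels=3,
)
```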

For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main/examples) for:
- tokenizing transcriptomes
- pretraining
- hyperparameter tuning
- fine-tuning
- extracting and plotting cell embeddings
- in silico perturbation

Please note that the fine-tuning examples are meant to be generally applicable, and the input datasets and labels will vary depending on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
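
For the tokenizing step listed above, a minimal sketch of the interface provided by geneformer/tokenizer.py; the custom attribute names, directory paths, and output prefix here are placeholders, and examples/tokenizing_scRNAseq_data.ipynb is the authoritative reference.

```python
# Hedged sketch: tokenize raw single-cell data into the rank value encoding
# dataset consumed by the fine-tuning examples. Paths and attribute names are placeholders.
from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
tk.tokenize_data("/path/to/loom_data_directory", "/path/to/output_directory", "output_prefix")
```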

Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application, as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
config.json
ADDED
@@ -0,0 +1,24 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.02,
  "classifier_dropout": null,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.02,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 2048,
  "model_type": "bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.32.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 15994
}
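
For reference, a small sketch that reads this configuration with transformers and prints the architecture values above; the hub id is illustrative and a local clone path works as well.

```python
# Read the model configuration shipped with this repository.
from transformers import BertConfig

config = BertConfig.from_pretrained("ctheodoris/Geneformer")
print(config.num_hidden_layers, config.hidden_size, config.max_position_embeddings)
# with this config.json: 6 256 2048
```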
examples/cell_classification.ipynb
ADDED
@@ -0,0 +1,1952 @@
## Geneformer Fine-Tuning for Cell Annotation Application

```python
import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
```

```python
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification
```

## Prepare training and evaluation datasets

```python
# load cell type dataset (includes all tissues)
train_dataset = load_from_disk("/path/to/cell_type_train_data.dataset")
```

```python
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset) * 0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))])

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
```

```python
trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))

evalset_dict = dict(zip(organ_list, evalset_list))
```

## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

```python
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
```

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

```python
# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
```
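
A minimal sketch of how these hyperparameters might be wired into a Hugging Face Trainer for a single organ, using the DataCollatorForCellClassification imported above; the pretrained-model path, output directory, organ name, and extra TrainingArguments values are placeholders, and the notebook's own training cell is the authoritative version.

```python
# Hedged sketch: fine-tune one organ's dataset with the hyperparameters above.
organ = "spleen"  # placeholder organ key
organ_trainset = trainset_dict[organ]
organ_evalset = evalset_dict[organ]
organ_label_dict = traintargetdict_dict[organ]

model = BertForSequenceClassification.from_pretrained(
    "/path/to/pretrained_geneformer_model/",      # placeholder checkpoint path
    num_labels=len(organ_label_dict),
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")

training_args = TrainingArguments(
    output_dir=f"/path/to/output/{organ}/",       # placeholder output directory
    learning_rate=max_lr,
    lr_scheduler_type=lr_schedule_fn,
    warmup_steps=warmup_steps,
    num_train_epochs=epochs,
    per_device_train_batch_size=geneformer_batch_size,
    per_device_eval_batch_size=geneformer_batch_size,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=organ_trainset,
    eval_dataset=organ_evalset,
    compute_metrics=compute_metrics,
)
trainer.train()
predictions = trainer.predict(organ_evalset)
```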

Training output (the final cell loads the pretrained checkpoint into BertForSequenceClassification and fine-tunes each organ in turn):

Loading the checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ reports that the masked-LM head weights (cls.predictions.*) were not used and that the pooler and classifier weights were newly initialized; this is expected when initializing BertForSequenceClassification from a masked-LM checkpoint, and the model should be trained on the downstream task before being used for predictions and inference. The collator line `batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}` also emits a recurring UserWarning recommending sourceTensor.clone().detach() over torch.tensor(sourceTensor).

spleen — [10280/10280 13:33, Epoch 10/10]; final evaluation [257/257 00:07]

| Epoch | Training Loss | Validation Loss | Accuracy | Macro F1 | Weighted F1 |
|------:|--------------:|----------------:|---------:|---------:|------------:|
| 1 | 0.087000 | 0.068067 | 0.985404 | 0.956839 | 0.985483 |
| 2 | 0.044400 | 0.075289 | 0.985079 | 0.955069 | 0.984898 |
| 3 | 0.066700 | 0.078703 | 0.983782 | 0.953240 | 0.983959 |
| 4 | 0.037400 | 0.057132 | 0.989945 | 0.970619 | 0.989883 |
| 5 | 0.025000 | 0.061644 | 0.988323 | 0.961126 | 0.988211 |
| 6 | 0.022400 | 0.065323 | 0.989296 | 0.969737 | 0.989362 |
| 7 | 0.018600 | 0.063710 | 0.989620 | 0.969436 | 0.989579 |
| 8 | 0.039800 | 0.065919 | 0.989945 | 0.968065 | 0.989802 |
| 9 | 0.030200 | 0.061359 | 0.990269 | 0.971700 | 0.990314 |
| 10 | 0.013400 | 0.059181 | 0.991567 | 0.974599 | 0.991552 |

kidney — [29340/29340 45:43, Epoch 10/10]; final evaluation [734/734 00:27]

| Epoch | Training Loss | Validation Loss | Accuracy | Macro F1 | Weighted F1 |
|------:|--------------:|----------------:|---------:|---------:|------------:|
| 1 | 0.326900 | 0.299193 | 0.912500 | 0.823067 | 0.909627 |
| 2 | 0.224200 | 0.239580 | 0.926477 | 0.850237 | 0.923902 |
| 3 | 0.221600 | 0.242810 | 0.930227 | 0.878553 | 0.930349 |
| 4 | 0.166100 | 0.264178 | 0.933409 | 0.884759 | 0.933031 |
| 5 | 0.144100 | 0.279282 | 0.935000 | 0.887659 | 0.934987 |
| 6 | 0.112800 | 0.307647 | 0.935909 | 0.889239 | 0.935365 |
| 7 | 0.084600 | 0.326399 | 0.932841 | 0.892447 | 0.933191 |
| 8 | 0.068300 | 0.332626 | 0.936591 | 0.891629 | 0.936354 |
| 9 | 0.065500 | 0.348174 | 0.935227 | 0.889484 | 0.935040 |
| 10 | 0.046100 | 0.355350 | 0.935000 | 0.894578 | 0.934971 |

lung — [21750/21750 30:32, Epoch 10/10]; final evaluation [544/544 00:19]

| Epoch | Training Loss | Validation Loss | Accuracy | Macro F1 | Weighted F1 |
|------:|--------------:|----------------:|---------:|---------:|------------:|
| 1 | 0.337600 | 0.341523 | 0.906360 | 0.759979 | 0.899310 |
| 2 | 0.211900 | 0.258954 | 0.928429 | 0.835534 | 0.925903 |
| 3 | 0.208600 | 0.282081 | 0.930421 | 0.842786 | 0.928013 |
| 4 | 0.144400 | 0.253047 | 0.935479 | 0.871712 | 0.935234 |
| 5 | 0.109200 | 0.268833 | 0.939464 | 0.876173 | 0.938870 |
| 6 | 0.132700 | 0.282697 | 0.940536 | 0.883271 | 0.940191 |
| 7 | 0.081800 | 0.295864 | 0.940843 | 0.884201 | 0.940170 |
| 8 | 0.035900 | 0.306600 | 0.941916 | 0.884777 | 0.941578 |
| 9 | 0.050800 | 0.311677 | 0.940536 | 0.883437 | 0.940294 |
| 10 | 0.035800 | 0.315360 | 0.940843 | 0.883551 | 0.940612 |

brain — [8880/8880 11:14, Epoch 10/10]

| Epoch | Training Loss | Validation Loss | Accuracy | Macro F1 | Weighted F1 |
|------:|--------------:|----------------:|---------:|---------:|------------:|
| 1 | 0.163100 | 0.156640 | 0.970345 | 0.736455 | 0.960714 |
| 2 | 0.149800 | 0.134897 | 0.968844 | 0.747114 | 0.960726 |
| 3 | 0.105600 | 0.115354 | 0.972222 | 0.775271 | 0.964932 |
| 4 | 0.086900 | 0.207918 | 0.968844 | 0.707927 | 0.958257 |
| 5 | 0.056400 | 0.106548 | 0.974099 | 0.839838 | 0.971611 |
| 6 | 0.037600 | 0.117437 | 0.978228 | 0.856578 | 0.975665 |
| 7 | 0.030500 | 0.127885 | 0.974474 | 0.856296 | 0.973531 |
| 8 | 0.019300 | 0.143203 | 0.977853 | 0.859362 | 0.975776 |
| 9 | 0.007400 | 0.153758 | 0.972598 | 0.852835 | 0.972314 |
| 10 | 0.017200 | 0.153911 | 0.975976 | 0.858196 | 0.974498 |
+
"text": [
|
909 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
910 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
911 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
912 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
913 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
914 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
915 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
916 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
917 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
918 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
919 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
920 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
921 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
922 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
923 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
924 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
925 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
926 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
927 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
928 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
929 |
+
]
|
930 |
+
},
|
931 |
+
{
|
932 |
+
"data": {
|
933 |
+
"text/html": [
|
934 |
+
"\n",
|
935 |
+
" <div>\n",
|
936 |
+
" \n",
|
937 |
+
" <progress value='222' max='222' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
938 |
+
" [222/222 00:04]\n",
|
939 |
+
" </div>\n",
|
940 |
+
" "
|
941 |
+
],
|
942 |
+
"text/plain": [
|
943 |
+
"<IPython.core.display.HTML object>"
|
944 |
+
]
|
945 |
+
},
|
946 |
+
"metadata": {},
|
947 |
+
"output_type": "display_data"
|
948 |
+
},
|
949 |
+
{
|
950 |
+
"name": "stderr",
|
951 |
+
"output_type": "stream",
|
952 |
+
"text": [
|
953 |
+
"Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
|
954 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
955 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
956 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
|
957 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
958 |
+
]
|
959 |
+
},
|
960 |
+
{
|
961 |
+
"name": "stdout",
|
962 |
+
"output_type": "stream",
|
963 |
+
"text": [
|
964 |
+
"placenta\n"
|
965 |
+
]
|
966 |
+
},
|
967 |
+
{
|
968 |
+
"name": "stderr",
|
969 |
+
"output_type": "stream",
|
970 |
+
"text": [
|
971 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
972 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
973 |
+
]
|
974 |
+
},
|
975 |
+
{
|
976 |
+
"data": {
|
977 |
+
"text/html": [
|
978 |
+
"\n",
|
979 |
+
" <div>\n",
|
980 |
+
" \n",
|
981 |
+
" <progress value='6180' max='6180' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
982 |
+
" [6180/6180 10:28, Epoch 10/10]\n",
|
983 |
+
" </div>\n",
|
984 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
985 |
+
" <thead>\n",
|
986 |
+
" <tr style=\"text-align: left;\">\n",
|
987 |
+
" <th>Epoch</th>\n",
|
988 |
+
" <th>Training Loss</th>\n",
|
989 |
+
" <th>Validation Loss</th>\n",
|
990 |
+
" <th>Accuracy</th>\n",
|
991 |
+
" <th>Macro F1</th>\n",
|
992 |
+
" <th>Weighted F1</th>\n",
|
993 |
+
" </tr>\n",
|
994 |
+
" </thead>\n",
|
995 |
+
" <tbody>\n",
|
996 |
+
" <tr>\n",
|
997 |
+
" <td>1</td>\n",
|
998 |
+
" <td>0.128700</td>\n",
|
999 |
+
" <td>0.125175</td>\n",
|
1000 |
+
" <td>0.960626</td>\n",
|
1001 |
+
" <td>0.935752</td>\n",
|
1002 |
+
" <td>0.959463</td>\n",
|
1003 |
+
" </tr>\n",
|
1004 |
+
" <tr>\n",
|
1005 |
+
" <td>2</td>\n",
|
1006 |
+
" <td>0.064000</td>\n",
|
1007 |
+
" <td>0.215607</td>\n",
|
1008 |
+
" <td>0.951456</td>\n",
|
1009 |
+
" <td>0.920579</td>\n",
|
1010 |
+
" <td>0.949828</td>\n",
|
1011 |
+
" </tr>\n",
|
1012 |
+
" <tr>\n",
|
1013 |
+
" <td>3</td>\n",
|
1014 |
+
" <td>0.051300</td>\n",
|
1015 |
+
" <td>0.203044</td>\n",
|
1016 |
+
" <td>0.961165</td>\n",
|
1017 |
+
" <td>0.934195</td>\n",
|
1018 |
+
" <td>0.959470</td>\n",
|
1019 |
+
" </tr>\n",
|
1020 |
+
" <tr>\n",
|
1021 |
+
" <td>4</td>\n",
|
1022 |
+
" <td>0.045300</td>\n",
|
1023 |
+
" <td>0.115701</td>\n",
|
1024 |
+
" <td>0.978964</td>\n",
|
1025 |
+
" <td>0.966387</td>\n",
|
1026 |
+
" <td>0.978788</td>\n",
|
1027 |
+
" </tr>\n",
|
1028 |
+
" <tr>\n",
|
1029 |
+
" <td>5</td>\n",
|
1030 |
+
" <td>0.048200</td>\n",
|
1031 |
+
" <td>0.149484</td>\n",
|
1032 |
+
" <td>0.973571</td>\n",
|
1033 |
+
" <td>0.958927</td>\n",
|
1034 |
+
" <td>0.973305</td>\n",
|
1035 |
+
" </tr>\n",
|
1036 |
+
" <tr>\n",
|
1037 |
+
" <td>6</td>\n",
|
1038 |
+
" <td>0.040900</td>\n",
|
1039 |
+
" <td>0.134339</td>\n",
|
1040 |
+
" <td>0.978964</td>\n",
|
1041 |
+
" <td>0.967466</td>\n",
|
1042 |
+
" <td>0.978899</td>\n",
|
1043 |
+
" </tr>\n",
|
1044 |
+
" <tr>\n",
|
1045 |
+
" <td>7</td>\n",
|
1046 |
+
" <td>0.001600</td>\n",
|
1047 |
+
" <td>0.159900</td>\n",
|
1048 |
+
" <td>0.978425</td>\n",
|
1049 |
+
" <td>0.966713</td>\n",
|
1050 |
+
" <td>0.978211</td>\n",
|
1051 |
+
" </tr>\n",
|
1052 |
+
" <tr>\n",
|
1053 |
+
" <td>8</td>\n",
|
1054 |
+
" <td>0.002400</td>\n",
|
1055 |
+
" <td>0.125351</td>\n",
|
1056 |
+
" <td>0.979504</td>\n",
|
1057 |
+
" <td>0.968064</td>\n",
|
1058 |
+
" <td>0.979428</td>\n",
|
1059 |
+
" </tr>\n",
|
1060 |
+
" <tr>\n",
|
1061 |
+
" <td>9</td>\n",
|
1062 |
+
" <td>0.009400</td>\n",
|
1063 |
+
" <td>0.120132</td>\n",
|
1064 |
+
" <td>0.980583</td>\n",
|
1065 |
+
" <td>0.969631</td>\n",
|
1066 |
+
" <td>0.980506</td>\n",
|
1067 |
+
" </tr>\n",
|
1068 |
+
" <tr>\n",
|
1069 |
+
" <td>10</td>\n",
|
1070 |
+
" <td>0.001500</td>\n",
|
1071 |
+
" <td>0.137864</td>\n",
|
1072 |
+
" <td>0.978964</td>\n",
|
1073 |
+
" <td>0.967180</td>\n",
|
1074 |
+
" <td>0.978825</td>\n",
|
1075 |
+
" </tr>\n",
|
1076 |
+
" </tbody>\n",
|
1077 |
+
"</table><p>"
|
1078 |
+
],
|
1079 |
+
"text/plain": [
|
1080 |
+
"<IPython.core.display.HTML object>"
|
1081 |
+
]
|
1082 |
+
},
|
1083 |
+
"metadata": {},
|
1084 |
+
"output_type": "display_data"
|
1085 |
+
},
|
1086 |
+
{
|
1087 |
+
"name": "stderr",
|
1088 |
+
"output_type": "stream",
|
1089 |
+
"text": [
|
1090 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1091 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1092 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1093 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1094 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1095 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1096 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1097 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1098 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1099 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1100 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1101 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1102 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1103 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1104 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1105 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1106 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1107 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1108 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1109 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1110 |
+
]
|
1111 |
+
},
|
1112 |
+
{
|
1113 |
+
"data": {
|
1114 |
+
"text/html": [
|
1115 |
+
"\n",
|
1116 |
+
" <div>\n",
|
1117 |
+
" \n",
|
1118 |
+
" <progress value='155' max='155' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1119 |
+
" [155/155 00:05]\n",
|
1120 |
+
" </div>\n",
|
1121 |
+
" "
|
1122 |
+
],
|
1123 |
+
"text/plain": [
|
1124 |
+
"<IPython.core.display.HTML object>"
|
1125 |
+
]
|
1126 |
+
},
|
1127 |
+
"metadata": {},
|
1128 |
+
"output_type": "display_data"
|
1129 |
+
},
|
1130 |
+
{
|
1131 |
+
"name": "stderr",
|
1132 |
+
"output_type": "stream",
|
1133 |
+
"text": [
|
1134 |
+
"Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
|
1135 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
1136 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
1137 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
|
1138 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
1139 |
+
]
|
1140 |
+
},
|
1141 |
+
{
|
1142 |
+
"name": "stdout",
|
1143 |
+
"output_type": "stream",
|
1144 |
+
"text": [
|
1145 |
+
"immune\n"
|
1146 |
+
]
|
1147 |
+
},
|
1148 |
+
{
|
1149 |
+
"name": "stderr",
|
1150 |
+
"output_type": "stream",
|
1151 |
+
"text": [
|
1152 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1153 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1154 |
+
]
|
1155 |
+
},
|
1156 |
+
{
|
1157 |
+
"data": {
|
1158 |
+
"text/html": [
|
1159 |
+
"\n",
|
1160 |
+
" <div>\n",
|
1161 |
+
" \n",
|
1162 |
+
" <progress value='17140' max='17140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1163 |
+
" [17140/17140 22:02, Epoch 10/10]\n",
|
1164 |
+
" </div>\n",
|
1165 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1166 |
+
" <thead>\n",
|
1167 |
+
" <tr style=\"text-align: left;\">\n",
|
1168 |
+
" <th>Epoch</th>\n",
|
1169 |
+
" <th>Training Loss</th>\n",
|
1170 |
+
" <th>Validation Loss</th>\n",
|
1171 |
+
" <th>Accuracy</th>\n",
|
1172 |
+
" <th>Macro F1</th>\n",
|
1173 |
+
" <th>Weighted F1</th>\n",
|
1174 |
+
" </tr>\n",
|
1175 |
+
" </thead>\n",
|
1176 |
+
" <tbody>\n",
|
1177 |
+
" <tr>\n",
|
1178 |
+
" <td>1</td>\n",
|
1179 |
+
" <td>0.288900</td>\n",
|
1180 |
+
" <td>0.231582</td>\n",
|
1181 |
+
" <td>0.936770</td>\n",
|
1182 |
+
" <td>0.868405</td>\n",
|
1183 |
+
" <td>0.934816</td>\n",
|
1184 |
+
" </tr>\n",
|
1185 |
+
" <tr>\n",
|
1186 |
+
" <td>2</td>\n",
|
1187 |
+
" <td>0.203200</td>\n",
|
1188 |
+
" <td>0.206292</td>\n",
|
1189 |
+
" <td>0.937354</td>\n",
|
1190 |
+
" <td>0.888661</td>\n",
|
1191 |
+
" <td>0.939555</td>\n",
|
1192 |
+
" </tr>\n",
|
1193 |
+
" <tr>\n",
|
1194 |
+
" <td>3</td>\n",
|
1195 |
+
" <td>0.183500</td>\n",
|
1196 |
+
" <td>0.195811</td>\n",
|
1197 |
+
" <td>0.944942</td>\n",
|
1198 |
+
" <td>0.891149</td>\n",
|
1199 |
+
" <td>0.944008</td>\n",
|
1200 |
+
" </tr>\n",
|
1201 |
+
" <tr>\n",
|
1202 |
+
" <td>4</td>\n",
|
1203 |
+
" <td>0.151000</td>\n",
|
1204 |
+
" <td>0.219581</td>\n",
|
1205 |
+
" <td>0.947665</td>\n",
|
1206 |
+
" <td>0.906578</td>\n",
|
1207 |
+
" <td>0.947093</td>\n",
|
1208 |
+
" </tr>\n",
|
1209 |
+
" <tr>\n",
|
1210 |
+
" <td>5</td>\n",
|
1211 |
+
" <td>0.090000</td>\n",
|
1212 |
+
" <td>0.247120</td>\n",
|
1213 |
+
" <td>0.946693</td>\n",
|
1214 |
+
" <td>0.898812</td>\n",
|
1215 |
+
" <td>0.945808</td>\n",
|
1216 |
+
" </tr>\n",
|
1217 |
+
" <tr>\n",
|
1218 |
+
" <td>6</td>\n",
|
1219 |
+
" <td>0.060400</td>\n",
|
1220 |
+
" <td>0.249662</td>\n",
|
1221 |
+
" <td>0.948444</td>\n",
|
1222 |
+
" <td>0.905014</td>\n",
|
1223 |
+
" <td>0.947975</td>\n",
|
1224 |
+
" </tr>\n",
|
1225 |
+
" <tr>\n",
|
1226 |
+
" <td>7</td>\n",
|
1227 |
+
" <td>0.071300</td>\n",
|
1228 |
+
" <td>0.272767</td>\n",
|
1229 |
+
" <td>0.949416</td>\n",
|
1230 |
+
" <td>0.911514</td>\n",
|
1231 |
+
" <td>0.949748</td>\n",
|
1232 |
+
" </tr>\n",
|
1233 |
+
" <tr>\n",
|
1234 |
+
" <td>8</td>\n",
|
1235 |
+
" <td>0.052600</td>\n",
|
1236 |
+
" <td>0.305051</td>\n",
|
1237 |
+
" <td>0.945331</td>\n",
|
1238 |
+
" <td>0.902348</td>\n",
|
1239 |
+
" <td>0.944987</td>\n",
|
1240 |
+
" </tr>\n",
|
1241 |
+
" <tr>\n",
|
1242 |
+
" <td>9</td>\n",
|
1243 |
+
" <td>0.026900</td>\n",
|
1244 |
+
" <td>0.294135</td>\n",
|
1245 |
+
" <td>0.948638</td>\n",
|
1246 |
+
" <td>0.904058</td>\n",
|
1247 |
+
" <td>0.948296</td>\n",
|
1248 |
+
" </tr>\n",
|
1249 |
+
" <tr>\n",
|
1250 |
+
" <td>10</td>\n",
|
1251 |
+
" <td>0.034500</td>\n",
|
1252 |
+
" <td>0.292029</td>\n",
|
1253 |
+
" <td>0.950195</td>\n",
|
1254 |
+
" <td>0.908547</td>\n",
|
1255 |
+
" <td>0.949753</td>\n",
|
1256 |
+
" </tr>\n",
|
1257 |
+
" </tbody>\n",
|
1258 |
+
"</table><p>"
|
1259 |
+
],
|
1260 |
+
"text/plain": [
|
1261 |
+
"<IPython.core.display.HTML object>"
|
1262 |
+
]
|
1263 |
+
},
|
1264 |
+
"metadata": {},
|
1265 |
+
"output_type": "display_data"
|
1266 |
+
},
|
1267 |
+
{
|
1268 |
+
"name": "stderr",
|
1269 |
+
"output_type": "stream",
|
1270 |
+
"text": [
|
1271 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1272 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1273 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1274 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1275 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1276 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1277 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1278 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1279 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1280 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1281 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1282 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1283 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1284 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1285 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1286 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1287 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1288 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1289 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1290 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1291 |
+
]
|
1292 |
+
},
|
1293 |
+
{
|
1294 |
+
"data": {
|
1295 |
+
"text/html": [
|
1296 |
+
"\n",
|
1297 |
+
" <div>\n",
|
1298 |
+
" \n",
|
1299 |
+
" <progress value='429' max='429' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1300 |
+
" [429/429 00:13]\n",
|
1301 |
+
" </div>\n",
|
1302 |
+
" "
|
1303 |
+
],
|
1304 |
+
"text/plain": [
|
1305 |
+
"<IPython.core.display.HTML object>"
|
1306 |
+
]
|
1307 |
+
},
|
1308 |
+
"metadata": {},
|
1309 |
+
"output_type": "display_data"
|
1310 |
+
},
|
1311 |
+
{
|
1312 |
+
"name": "stderr",
|
1313 |
+
"output_type": "stream",
|
1314 |
+
"text": [
|
1315 |
+
"Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
|
1316 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
1317 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
1318 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
|
1319 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
1320 |
+
]
|
1321 |
+
},
|
1322 |
+
{
|
1323 |
+
"name": "stdout",
|
1324 |
+
"output_type": "stream",
|
1325 |
+
"text": [
|
1326 |
+
"large_intestine\n"
|
1327 |
+
]
|
1328 |
+
},
|
1329 |
+
{
|
1330 |
+
"name": "stderr",
|
1331 |
+
"output_type": "stream",
|
1332 |
+
"text": [
|
1333 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1334 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1335 |
+
]
|
1336 |
+
},
|
1337 |
+
{
|
1338 |
+
"data": {
|
1339 |
+
"text/html": [
|
1340 |
+
"\n",
|
1341 |
+
" <div>\n",
|
1342 |
+
" \n",
|
1343 |
+
" <progress value='33070' max='33070' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1344 |
+
" [33070/33070 43:02, Epoch 10/10]\n",
|
1345 |
+
" </div>\n",
|
1346 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1347 |
+
" <thead>\n",
|
1348 |
+
" <tr style=\"text-align: left;\">\n",
|
1349 |
+
" <th>Epoch</th>\n",
|
1350 |
+
" <th>Training Loss</th>\n",
|
1351 |
+
" <th>Validation Loss</th>\n",
|
1352 |
+
" <th>Accuracy</th>\n",
|
1353 |
+
" <th>Macro F1</th>\n",
|
1354 |
+
" <th>Weighted F1</th>\n",
|
1355 |
+
" </tr>\n",
|
1356 |
+
" </thead>\n",
|
1357 |
+
" <tbody>\n",
|
1358 |
+
" <tr>\n",
|
1359 |
+
" <td>1</td>\n",
|
1360 |
+
" <td>0.306200</td>\n",
|
1361 |
+
" <td>0.312431</td>\n",
|
1362 |
+
" <td>0.908266</td>\n",
|
1363 |
+
" <td>0.786242</td>\n",
|
1364 |
+
" <td>0.900768</td>\n",
|
1365 |
+
" </tr>\n",
|
1366 |
+
" <tr>\n",
|
1367 |
+
" <td>2</td>\n",
|
1368 |
+
" <td>0.223900</td>\n",
|
1369 |
+
" <td>0.248096</td>\n",
|
1370 |
+
" <td>0.925101</td>\n",
|
1371 |
+
" <td>0.841251</td>\n",
|
1372 |
+
" <td>0.920987</td>\n",
|
1373 |
+
" </tr>\n",
|
1374 |
+
" <tr>\n",
|
1375 |
+
" <td>3</td>\n",
|
1376 |
+
" <td>0.173600</td>\n",
|
1377 |
+
" <td>0.259997</td>\n",
|
1378 |
+
" <td>0.925907</td>\n",
|
1379 |
+
" <td>0.850348</td>\n",
|
1380 |
+
" <td>0.926290</td>\n",
|
1381 |
+
" </tr>\n",
|
1382 |
+
" <tr>\n",
|
1383 |
+
" <td>4</td>\n",
|
1384 |
+
" <td>0.162900</td>\n",
|
1385 |
+
" <td>0.282306</td>\n",
|
1386 |
+
" <td>0.925000</td>\n",
|
1387 |
+
" <td>0.873669</td>\n",
|
1388 |
+
" <td>0.925531</td>\n",
|
1389 |
+
" </tr>\n",
|
1390 |
+
" <tr>\n",
|
1391 |
+
" <td>5</td>\n",
|
1392 |
+
" <td>0.143400</td>\n",
|
1393 |
+
" <td>0.254494</td>\n",
|
1394 |
+
" <td>0.937903</td>\n",
|
1395 |
+
" <td>0.876749</td>\n",
|
1396 |
+
" <td>0.937836</td>\n",
|
1397 |
+
" </tr>\n",
|
1398 |
+
" <tr>\n",
|
1399 |
+
" <td>6</td>\n",
|
1400 |
+
" <td>0.104500</td>\n",
|
1401 |
+
" <td>0.289942</td>\n",
|
1402 |
+
" <td>0.934677</td>\n",
|
1403 |
+
" <td>0.875333</td>\n",
|
1404 |
+
" <td>0.934339</td>\n",
|
1405 |
+
" </tr>\n",
|
1406 |
+
" <tr>\n",
|
1407 |
+
" <td>7</td>\n",
|
1408 |
+
" <td>0.080300</td>\n",
|
1409 |
+
" <td>0.313914</td>\n",
|
1410 |
+
" <td>0.935484</td>\n",
|
1411 |
+
" <td>0.877271</td>\n",
|
1412 |
+
" <td>0.934986</td>\n",
|
1413 |
+
" </tr>\n",
|
1414 |
+
" <tr>\n",
|
1415 |
+
" <td>8</td>\n",
|
1416 |
+
" <td>0.063500</td>\n",
|
1417 |
+
" <td>0.339868</td>\n",
|
1418 |
+
" <td>0.936290</td>\n",
|
1419 |
+
" <td>0.882267</td>\n",
|
1420 |
+
" <td>0.936187</td>\n",
|
1421 |
+
" </tr>\n",
|
1422 |
+
" <tr>\n",
|
1423 |
+
" <td>9</td>\n",
|
1424 |
+
" <td>0.042500</td>\n",
|
1425 |
+
" <td>0.345784</td>\n",
|
1426 |
+
" <td>0.938911</td>\n",
|
1427 |
+
" <td>0.882963</td>\n",
|
1428 |
+
" <td>0.938682</td>\n",
|
1429 |
+
" </tr>\n",
|
1430 |
+
" <tr>\n",
|
1431 |
+
" <td>10</td>\n",
|
1432 |
+
" <td>0.038900</td>\n",
|
1433 |
+
" <td>0.352199</td>\n",
|
1434 |
+
" <td>0.939516</td>\n",
|
1435 |
+
" <td>0.885509</td>\n",
|
1436 |
+
" <td>0.939497</td>\n",
|
1437 |
+
" </tr>\n",
|
1438 |
+
" </tbody>\n",
|
1439 |
+
"</table><p>"
|
1440 |
+
],
|
1441 |
+
"text/plain": [
|
1442 |
+
"<IPython.core.display.HTML object>"
|
1443 |
+
]
|
1444 |
+
},
|
1445 |
+
"metadata": {},
|
1446 |
+
"output_type": "display_data"
|
1447 |
+
},
|
1448 |
+
{
|
1449 |
+
"name": "stderr",
|
1450 |
+
"output_type": "stream",
|
1451 |
+
"text": [
|
1452 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1453 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1454 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1455 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1456 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1457 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1458 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1459 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1460 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1461 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1462 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1463 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1464 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1465 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1466 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1467 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1468 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1469 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1470 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1471 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1472 |
+
]
|
1473 |
+
},
|
1474 |
+
{
|
1475 |
+
"data": {
|
1476 |
+
"text/html": [
|
1477 |
+
"\n",
|
1478 |
+
" <div>\n",
|
1479 |
+
" \n",
|
1480 |
+
" <progress value='827' max='827' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1481 |
+
" [827/827 00:26]\n",
|
1482 |
+
" </div>\n",
|
1483 |
+
" "
|
1484 |
+
],
|
1485 |
+
"text/plain": [
|
1486 |
+
"<IPython.core.display.HTML object>"
|
1487 |
+
]
|
1488 |
+
},
|
1489 |
+
"metadata": {},
|
1490 |
+
"output_type": "display_data"
|
1491 |
+
},
|
1492 |
+
{
|
1493 |
+
"name": "stderr",
|
1494 |
+
"output_type": "stream",
|
1495 |
+
"text": [
|
1496 |
+
"Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
|
1497 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
1498 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
1499 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
|
1500 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
1501 |
+
]
|
1502 |
+
},
|
1503 |
+
{
|
1504 |
+
"name": "stdout",
|
1505 |
+
"output_type": "stream",
|
1506 |
+
"text": [
|
1507 |
+
"pancreas\n"
|
1508 |
+
]
|
1509 |
+
},
|
1510 |
+
{
|
1511 |
+
"name": "stderr",
|
1512 |
+
"output_type": "stream",
|
1513 |
+
"text": [
|
1514 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1515 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1516 |
+
]
|
1517 |
+
},
|
1518 |
+
{
|
1519 |
+
"data": {
|
1520 |
+
"text/html": [
|
1521 |
+
"\n",
|
1522 |
+
" <div>\n",
|
1523 |
+
" \n",
|
1524 |
+
" <progress value='18280' max='18280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1525 |
+
" [18280/18280 23:32, Epoch 10/10]\n",
|
1526 |
+
" </div>\n",
|
1527 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1528 |
+
" <thead>\n",
|
1529 |
+
" <tr style=\"text-align: left;\">\n",
|
1530 |
+
" <th>Epoch</th>\n",
|
1531 |
+
" <th>Training Loss</th>\n",
|
1532 |
+
" <th>Validation Loss</th>\n",
|
1533 |
+
" <th>Accuracy</th>\n",
|
1534 |
+
" <th>Macro F1</th>\n",
|
1535 |
+
" <th>Weighted F1</th>\n",
|
1536 |
+
" </tr>\n",
|
1537 |
+
" </thead>\n",
|
1538 |
+
" <tbody>\n",
|
1539 |
+
" <tr>\n",
|
1540 |
+
" <td>1</td>\n",
|
1541 |
+
" <td>0.340100</td>\n",
|
1542 |
+
" <td>0.343200</td>\n",
|
1543 |
+
" <td>0.896244</td>\n",
|
1544 |
+
" <td>0.655661</td>\n",
|
1545 |
+
" <td>0.879469</td>\n",
|
1546 |
+
" </tr>\n",
|
1547 |
+
" <tr>\n",
|
1548 |
+
" <td>2</td>\n",
|
1549 |
+
" <td>0.178300</td>\n",
|
1550 |
+
" <td>0.224033</td>\n",
|
1551 |
+
" <td>0.930890</td>\n",
|
1552 |
+
" <td>0.859772</td>\n",
|
1553 |
+
" <td>0.925342</td>\n",
|
1554 |
+
" </tr>\n",
|
1555 |
+
" <tr>\n",
|
1556 |
+
" <td>3</td>\n",
|
1557 |
+
" <td>0.154200</td>\n",
|
1558 |
+
" <td>0.208034</td>\n",
|
1559 |
+
" <td>0.941284</td>\n",
|
1560 |
+
" <td>0.887012</td>\n",
|
1561 |
+
" <td>0.939485</td>\n",
|
1562 |
+
" </tr>\n",
|
1563 |
+
" <tr>\n",
|
1564 |
+
" <td>4</td>\n",
|
1565 |
+
" <td>0.121200</td>\n",
|
1566 |
+
" <td>0.216660</td>\n",
|
1567 |
+
" <td>0.940372</td>\n",
|
1568 |
+
" <td>0.880716</td>\n",
|
1569 |
+
" <td>0.939431</td>\n",
|
1570 |
+
" </tr>\n",
|
1571 |
+
" <tr>\n",
|
1572 |
+
" <td>5</td>\n",
|
1573 |
+
" <td>0.099900</td>\n",
|
1574 |
+
" <td>0.254255</td>\n",
|
1575 |
+
" <td>0.940554</td>\n",
|
1576 |
+
" <td>0.889088</td>\n",
|
1577 |
+
" <td>0.938300</td>\n",
|
1578 |
+
" </tr>\n",
|
1579 |
+
" <tr>\n",
|
1580 |
+
" <td>6</td>\n",
|
1581 |
+
" <td>0.065800</td>\n",
|
1582 |
+
" <td>0.267429</td>\n",
|
1583 |
+
" <td>0.942743</td>\n",
|
1584 |
+
" <td>0.897682</td>\n",
|
1585 |
+
" <td>0.942815</td>\n",
|
1586 |
+
" </tr>\n",
|
1587 |
+
" <tr>\n",
|
1588 |
+
" <td>7</td>\n",
|
1589 |
+
" <td>0.061200</td>\n",
|
1590 |
+
" <td>0.282509</td>\n",
|
1591 |
+
" <td>0.945478</td>\n",
|
1592 |
+
" <td>0.898797</td>\n",
|
1593 |
+
" <td>0.943881</td>\n",
|
1594 |
+
" </tr>\n",
|
1595 |
+
" <tr>\n",
|
1596 |
+
" <td>8</td>\n",
|
1597 |
+
" <td>0.036800</td>\n",
|
1598 |
+
" <td>0.301781</td>\n",
|
1599 |
+
" <td>0.943837</td>\n",
|
1600 |
+
" <td>0.903816</td>\n",
|
1601 |
+
" <td>0.944163</td>\n",
|
1602 |
+
" </tr>\n",
|
1603 |
+
" <tr>\n",
|
1604 |
+
" <td>9</td>\n",
|
1605 |
+
" <td>0.035400</td>\n",
|
1606 |
+
" <td>0.317026</td>\n",
|
1607 |
+
" <td>0.942560</td>\n",
|
1608 |
+
" <td>0.902241</td>\n",
|
1609 |
+
" <td>0.942071</td>\n",
|
1610 |
+
" </tr>\n",
|
1611 |
+
" <tr>\n",
|
1612 |
+
" <td>10</td>\n",
|
1613 |
+
" <td>0.014200</td>\n",
|
1614 |
+
" <td>0.313259</td>\n",
|
1615 |
+
" <td>0.946754</td>\n",
|
1616 |
+
" <td>0.904955</td>\n",
|
1617 |
+
" <td>0.946129</td>\n",
|
1618 |
+
" </tr>\n",
|
1619 |
+
" </tbody>\n",
|
1620 |
+
"</table><p>"
|
1621 |
+
],
|
1622 |
+
"text/plain": [
|
1623 |
+
"<IPython.core.display.HTML object>"
|
1624 |
+
]
|
1625 |
+
},
|
1626 |
+
"metadata": {},
|
1627 |
+
"output_type": "display_data"
|
1628 |
+
},
|
1629 |
+
{
|
1630 |
+
"name": "stderr",
|
1631 |
+
"output_type": "stream",
|
1632 |
+
"text": [
|
1633 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1634 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1635 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1636 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1637 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1638 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1639 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1640 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1641 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1642 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1643 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1644 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1645 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1646 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1647 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1648 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1649 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1650 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1651 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1652 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1653 |
+
]
|
1654 |
+
},
|
1655 |
+
{
|
1656 |
+
"data": {
|
1657 |
+
"text/html": [
|
1658 |
+
"\n",
|
1659 |
+
" <div>\n",
|
1660 |
+
" \n",
|
1661 |
+
" <progress value='457' max='457' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1662 |
+
" [457/457 00:11]\n",
|
1663 |
+
" </div>\n",
|
1664 |
+
" "
|
1665 |
+
],
|
1666 |
+
"text/plain": [
|
1667 |
+
"<IPython.core.display.HTML object>"
|
1668 |
+
]
|
1669 |
+
},
|
1670 |
+
"metadata": {},
|
1671 |
+
"output_type": "display_data"
|
1672 |
+
},
|
1673 |
+
{
|
1674 |
+
"name": "stderr",
|
1675 |
+
"output_type": "stream",
|
1676 |
+
"text": [
|
1677 |
+
"Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
|
1678 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
1679 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
1680 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
|
1681 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
1682 |
+
]
|
1683 |
+
},
|
1684 |
+
{
|
1685 |
+
"name": "stdout",
|
1686 |
+
"output_type": "stream",
|
1687 |
+
"text": [
|
1688 |
+
"liver\n"
|
1689 |
+
]
|
1690 |
+
},
|
1691 |
+
{
|
1692 |
+
"name": "stderr",
|
1693 |
+
"output_type": "stream",
|
1694 |
+
"text": [
|
1695 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1696 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1697 |
+
]
|
1698 |
+
},
|
1699 |
+
{
|
1700 |
+
"data": {
|
1701 |
+
"text/html": [
|
1702 |
+
"\n",
|
1703 |
+
" <div>\n",
|
1704 |
+
" \n",
|
1705 |
+
" <progress value='18690' max='18690' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1706 |
+
" [18690/18690 26:56, Epoch 10/10]\n",
|
1707 |
+
" </div>\n",
|
1708 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1709 |
+
" <thead>\n",
|
1710 |
+
" <tr style=\"text-align: left;\">\n",
|
1711 |
+
" <th>Epoch</th>\n",
|
1712 |
+
" <th>Training Loss</th>\n",
|
1713 |
+
" <th>Validation Loss</th>\n",
|
1714 |
+
" <th>Accuracy</th>\n",
|
1715 |
+
" <th>Macro F1</th>\n",
|
1716 |
+
" <th>Weighted F1</th>\n",
|
1717 |
+
" </tr>\n",
|
1718 |
+
" </thead>\n",
|
1719 |
+
" <tbody>\n",
|
1720 |
+
" <tr>\n",
|
1721 |
+
" <td>1</td>\n",
|
1722 |
+
" <td>0.388500</td>\n",
|
1723 |
+
" <td>0.385503</td>\n",
|
1724 |
+
" <td>0.878188</td>\n",
|
1725 |
+
" <td>0.673887</td>\n",
|
1726 |
+
" <td>0.871348</td>\n",
|
1727 |
+
" </tr>\n",
|
1728 |
+
" <tr>\n",
|
1729 |
+
" <td>2</td>\n",
|
1730 |
+
" <td>0.315900</td>\n",
|
1731 |
+
" <td>0.302775</td>\n",
|
1732 |
+
" <td>0.907437</td>\n",
|
1733 |
+
" <td>0.754182</td>\n",
|
1734 |
+
" <td>0.903474</td>\n",
|
1735 |
+
" </tr>\n",
|
1736 |
+
" <tr>\n",
|
1737 |
+
" <td>3</td>\n",
|
1738 |
+
" <td>0.242600</td>\n",
|
1739 |
+
" <td>0.321844</td>\n",
|
1740 |
+
" <td>0.907972</td>\n",
|
1741 |
+
" <td>0.779504</td>\n",
|
1742 |
+
" <td>0.905881</td>\n",
|
1743 |
+
" </tr>\n",
|
1744 |
+
" <tr>\n",
|
1745 |
+
" <td>4</td>\n",
|
1746 |
+
" <td>0.238600</td>\n",
|
1747 |
+
" <td>0.323119</td>\n",
|
1748 |
+
" <td>0.911539</td>\n",
|
1749 |
+
" <td>0.790922</td>\n",
|
1750 |
+
" <td>0.910299</td>\n",
|
1751 |
+
" </tr>\n",
|
1752 |
+
" <tr>\n",
|
1753 |
+
" <td>5</td>\n",
|
1754 |
+
" <td>0.160100</td>\n",
|
1755 |
+
" <td>0.328203</td>\n",
|
1756 |
+
" <td>0.915641</td>\n",
|
1757 |
+
" <td>0.793490</td>\n",
|
1758 |
+
" <td>0.913836</td>\n",
|
1759 |
+
" </tr>\n",
|
1760 |
+
" <tr>\n",
|
1761 |
+
" <td>6</td>\n",
|
1762 |
+
" <td>0.163100</td>\n",
|
1763 |
+
" <td>0.348942</td>\n",
|
1764 |
+
" <td>0.917425</td>\n",
|
1765 |
+
" <td>0.813604</td>\n",
|
1766 |
+
" <td>0.916911</td>\n",
|
1767 |
+
" </tr>\n",
|
1768 |
+
" <tr>\n",
|
1769 |
+
" <td>7</td>\n",
|
1770 |
+
" <td>0.124100</td>\n",
|
1771 |
+
" <td>0.373799</td>\n",
|
1772 |
+
" <td>0.916890</td>\n",
|
1773 |
+
" <td>0.820355</td>\n",
|
1774 |
+
" <td>0.916688</td>\n",
|
1775 |
+
" </tr>\n",
|
1776 |
+
" <tr>\n",
|
1777 |
+
" <td>8</td>\n",
|
1778 |
+
" <td>0.118700</td>\n",
|
1779 |
+
" <td>0.399474</td>\n",
|
1780 |
+
" <td>0.916890</td>\n",
|
1781 |
+
" <td>0.818839</td>\n",
|
1782 |
+
" <td>0.916640</td>\n",
|
1783 |
+
" </tr>\n",
|
1784 |
+
" <tr>\n",
|
1785 |
+
" <td>9</td>\n",
|
1786 |
+
" <td>0.066800</td>\n",
|
1787 |
+
" <td>0.414363</td>\n",
|
1788 |
+
" <td>0.917603</td>\n",
|
1789 |
+
" <td>0.830703</td>\n",
|
1790 |
+
" <td>0.917226</td>\n",
|
1791 |
+
" </tr>\n",
|
1792 |
+
" <tr>\n",
|
1793 |
+
" <td>10</td>\n",
|
1794 |
+
" <td>0.075800</td>\n",
|
1795 |
+
" <td>0.413828</td>\n",
|
1796 |
+
" <td>0.919030</td>\n",
|
1797 |
+
" <td>0.828149</td>\n",
|
1798 |
+
" <td>0.918506</td>\n",
|
1799 |
+
" </tr>\n",
|
1800 |
+
" </tbody>\n",
|
1801 |
+
"</table><p>"
|
1802 |
+
],
|
1803 |
+
"text/plain": [
|
1804 |
+
"<IPython.core.display.HTML object>"
|
1805 |
+
]
|
1806 |
+
},
|
1807 |
+
"metadata": {},
|
1808 |
+
"output_type": "display_data"
|
1809 |
+
},
|
1810 |
+
{
|
1811 |
+
"name": "stderr",
|
1812 |
+
"output_type": "stream",
|
1813 |
+
"text": [
|
1814 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1815 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1816 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1817 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1818 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1819 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1820 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1821 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1822 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1823 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1824 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1825 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1826 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1827 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1828 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1829 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1830 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1831 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
|
1832 |
+
"<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
1833 |
+
" batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
|
1834 |
+
]
|
1835 |
+
},
|
1836 |
+
{
|
1837 |
+
"data": {
|
1838 |
+
"text/html": [
|
1839 |
+
"\n",
|
1840 |
+
" <div>\n",
|
1841 |
+
" \n",
|
1842 |
+
" <progress value='936' max='468' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1843 |
+
" [468/468 00:39]\n",
|
1844 |
+
" </div>\n",
|
1845 |
+
" "
|
1846 |
+
],
|
1847 |
+
"text/plain": [
|
1848 |
+
"<IPython.core.display.HTML object>"
|
1849 |
+
]
|
1850 |
+
},
|
1851 |
+
"metadata": {},
|
1852 |
+
"output_type": "display_data"
|
1853 |
+
}
|
1854 |
+
],
|
1855 |
+
"source": [
|
1856 |
+
"for organ in organ_list:\n",
|
1857 |
+
" print(organ)\n",
|
1858 |
+
" organ_trainset = trainset_dict[organ]\n",
|
1859 |
+
" organ_evalset = evalset_dict[organ]\n",
|
1860 |
+
" organ_label_dict = traintargetdict_dict[organ]\n",
|
1861 |
+
" \n",
|
1862 |
+
" # set logging steps\n",
|
1863 |
+
" logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)\n",
|
1864 |
+
" \n",
|
1865 |
+
" # reload pretrained model\n",
|
1866 |
+
" model = BertForSequenceClassification.from_pretrained(\"/path/to/pretrained_model/\", \n",
|
1867 |
+
" num_labels=len(organ_label_dict.keys()),\n",
|
1868 |
+
" output_attentions = False,\n",
|
1869 |
+
" output_hidden_states = False).to(\"cuda\")\n",
|
1870 |
+
" \n",
|
1871 |
+
" # define output directory path\n",
|
1872 |
+
" current_date = datetime.datetime.now()\n",
|
1873 |
+
" datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
|
1874 |
+
" output_dir = f\"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/\"\n",
|
1875 |
+
" \n",
|
1876 |
+
" # ensure not overwriting previously saved model\n",
|
1877 |
+
" saved_model_test = os.path.join(output_dir, f\"pytorch_model.bin\")\n",
|
1878 |
+
" if os.path.isfile(saved_model_test) == True:\n",
|
1879 |
+
" raise Exception(\"Model already saved to this directory.\")\n",
|
1880 |
+
"\n",
|
1881 |
+
" # make output directory\n",
|
1882 |
+
" subprocess.call(f'mkdir {output_dir}', shell=True)\n",
|
1883 |
+
" \n",
|
1884 |
+
" # set training arguments\n",
|
1885 |
+
" training_args = {\n",
|
1886 |
+
" \"learning_rate\": max_lr,\n",
|
1887 |
+
" \"do_train\": True,\n",
|
1888 |
+
" \"do_eval\": True,\n",
|
1889 |
+
" \"evaluation_strategy\": \"epoch\",\n",
|
1890 |
+
" \"save_strategy\": \"epoch\",\n",
|
1891 |
+
" \"logging_steps\": logging_steps,\n",
|
1892 |
+
" \"group_by_length\": True,\n",
|
1893 |
+
" \"length_column_name\": \"length\",\n",
|
1894 |
+
" \"disable_tqdm\": False,\n",
|
1895 |
+
" \"lr_scheduler_type\": lr_schedule_fn,\n",
|
1896 |
+
" \"warmup_steps\": warmup_steps,\n",
|
1897 |
+
" \"weight_decay\": 0.001,\n",
|
1898 |
+
" \"per_device_train_batch_size\": geneformer_batch_size,\n",
|
1899 |
+
" \"per_device_eval_batch_size\": geneformer_batch_size,\n",
|
1900 |
+
" \"num_train_epochs\": epochs,\n",
|
1901 |
+
" \"load_best_model_at_end\": True,\n",
|
1902 |
+
" \"output_dir\": output_dir,\n",
|
1903 |
+
" }\n",
|
1904 |
+
" \n",
|
1905 |
+
" training_args_init = TrainingArguments(**training_args)\n",
|
1906 |
+
"\n",
|
1907 |
+
" # create the trainer\n",
|
1908 |
+
" trainer = Trainer(\n",
|
1909 |
+
" model=model,\n",
|
1910 |
+
" args=training_args_init,\n",
|
1911 |
+
" data_collator=DataCollatorForCellClassification(),\n",
|
1912 |
+
" train_dataset=organ_trainset,\n",
|
1913 |
+
" eval_dataset=organ_evalset,\n",
|
1914 |
+
" compute_metrics=compute_metrics\n",
|
1915 |
+
" )\n",
|
1916 |
+
" # train the cell type classifier\n",
|
1917 |
+
" trainer.train()\n",
|
1918 |
+
" predictions = trainer.predict(organ_evalset)\n",
|
1919 |
+
" with open(f\"{output_dir}predictions.pickle\", \"wb\") as fp:\n",
|
1920 |
+
" pickle.dump(predictions, fp)\n",
|
1921 |
+
" trainer.save_metrics(\"eval\",predictions.metrics)\n",
|
1922 |
+
" trainer.save_model(output_dir)"
|
1923 |
+
]
|
1924 |
+
}
|
1925 |
+
],
|
1926 |
+
"metadata": {
|
1927 |
+
"kernelspec": {
|
1928 |
+
"display_name": "Python 3 (ipykernel)",
|
1929 |
+
"language": "python",
|
1930 |
+
"name": "python3"
|
1931 |
+
},
|
1932 |
+
"language_info": {
|
1933 |
+
"codemirror_mode": {
|
1934 |
+
"name": "ipython",
|
1935 |
+
"version": 3
|
1936 |
+
},
|
1937 |
+
"file_extension": ".py",
|
1938 |
+
"mimetype": "text/x-python",
|
1939 |
+
"name": "python",
|
1940 |
+
"nbconvert_exporter": "python",
|
1941 |
+
"pygments_lexer": "ipython3",
|
1942 |
+
"version": "3.10.11"
|
1943 |
+
},
|
1944 |
+
"vscode": {
|
1945 |
+
"interpreter": {
|
1946 |
+
"hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
|
1947 |
+
}
|
1948 |
+
}
|
1949 |
+
},
|
1950 |
+
"nbformat": 4,
|
1951 |
+
"nbformat_minor": 5
|
1952 |
+
}
|
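The loop above saves each organ's predictions.pickle alongside the fine-tuned model. A minimal sketch (not part of the committed notebook) of reloading and summarizing those predictions; `output_dir` is the placeholder path used in the loop, and the confusion-matrix step via scikit-learn is an added illustration:

import pickle
import numpy as np
from sklearn.metrics import confusion_matrix

with open(f"{output_dir}predictions.pickle", "rb") as fp:
    predictions = pickle.load(fp)                      # a transformers PredictionOutput

pred_ids = np.argmax(predictions.predictions, axis=-1)   # logits -> predicted label ids
print(predictions.metrics)                               # metrics computed by compute_metrics
print(confusion_matrix(predictions.label_ids, pred_ids)) # per-class breakdown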
examples/extract_and_plot_cell_embeddings.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/gene_classification.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/hyperparam_optimiz_for_disease_classifier.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# hyperparameter optimization with raytune for disease classification
|
5 |
+
|
6 |
+
# imports
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
GPU_NUMBER = [0,1,2,3]
|
10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
11 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
12 |
+
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
13 |
+
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
|
14 |
+
|
15 |
+
# initiate runtime environment for raytune
|
16 |
+
import pyarrow # must occur prior to ray import
|
17 |
+
import ray
|
18 |
+
from ray import tune
|
19 |
+
from ray.tune import ExperimentAnalysis
|
20 |
+
from ray.tune.suggest.hyperopt import HyperOptSearch
|
21 |
+
ray.shutdown() #engage new ray session
|
22 |
+
runtime_env = {"conda": "base",
|
23 |
+
"env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
|
24 |
+
ray.init(runtime_env=runtime_env)
|
25 |
+
|
26 |
+
def initialize_ray_with_check(ip_address):
|
27 |
+
"""
|
28 |
+
Initialize Ray with a specified IP address and check its status and accessibility.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
- ip_address (str): The IP address (with port) to initialize Ray.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
- bool: True if initialization was successful and dashboard is accessible, False otherwise.
|
35 |
+
"""
|
36 |
+
try:
|
37 |
+
ray.init(address=ip_address)
|
38 |
+
print(ray.nodes())
|
39 |
+
|
40 |
+
services = ray.get_webui_url()
|
41 |
+
if not services:
|
42 |
+
raise RuntimeError("Ray dashboard is not accessible.")
|
43 |
+
else:
|
44 |
+
print(f"Ray dashboard is accessible at: {services}")
|
45 |
+
return True
|
46 |
+
except Exception as e:
|
47 |
+
print(f"Error initializing Ray: {e}")
|
48 |
+
return False
|
49 |
+
|
50 |
+
# Usage:
|
51 |
+
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
52 |
+
if initialize_ray_with_check(ip):
|
53 |
+
print("Ray initialized successfully.")
|
54 |
+
else:
|
55 |
+
print("Error during Ray initialization.")
|
56 |
+
|
57 |
+
import datetime
|
58 |
+
import numpy as np
|
59 |
+
import pandas as pd
|
60 |
+
import random
|
61 |
+
import seaborn as sns; sns.set()
|
62 |
+
from collections import Counter
|
63 |
+
from datasets import load_from_disk
|
64 |
+
from scipy.stats import ranksums
|
65 |
+
from sklearn.metrics import accuracy_score
|
66 |
+
from transformers import BertForSequenceClassification
|
67 |
+
from transformers import Trainer
|
68 |
+
from transformers.training_args import TrainingArguments
|
69 |
+
|
70 |
+
from geneformer import DataCollatorForCellClassification
|
71 |
+
|
72 |
+
# number of CPU cores
|
73 |
+
num_proc=30
|
74 |
+
|
75 |
+
# load train dataset with columns:
|
76 |
+
# cell_type (annotation of each cell's type)
|
77 |
+
# disease (healthy or disease state)
|
78 |
+
# individual (unique ID for each patient)
|
79 |
+
# length (length of that cell's rank value encoding)
|
80 |
+
train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
|
81 |
+
|
82 |
+
# filter dataset for given cell_type
|
83 |
+
def if_cell_type(example):
|
84 |
+
return example["cell_type"].startswith("Cardiomyocyte")
|
85 |
+
|
86 |
+
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
|
87 |
+
|
88 |
+
# create dictionary of disease states : label ids
|
89 |
+
target_names = ["healthy", "disease1", "disease2"]
|
90 |
+
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
|
91 |
+
|
92 |
+
trainset_v3 = trainset_v2.rename_column("disease","label")
|
93 |
+
|
94 |
+
# change labels to numerical ids
|
95 |
+
def classes_to_ids(example):
|
96 |
+
example["label"] = target_name_id_dict[example["label"]]
|
97 |
+
return example
|
98 |
+
|
99 |
+
trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
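# Illustrative check (not in the original script): confirm the disease labels were
# mapped to integer ids as intended; Counter is already imported above.
print(Counter(trainset_v4["label"]))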
|
100 |
+
|
101 |
+
# separate into train, validation, test sets
|
102 |
+
indiv_set = set(trainset_v4["individual"])
|
103 |
+
random.seed(42)
|
104 |
+
train_indiv = random.sample(sorted(indiv_set), round(0.7*len(indiv_set)))  # sort first: sampling a set directly is not reproducible across runs and is unsupported in Python 3.11+
|
105 |
+
eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
|
106 |
+
valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
|
107 |
+
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
|
108 |
+
|
109 |
+
def if_train(example):
|
110 |
+
return example["individual"] in train_indiv
|
111 |
+
|
112 |
+
classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
|
113 |
+
|
114 |
+
def if_valid(example):
|
115 |
+
return example["individual"] in valid_indiv
|
116 |
+
|
117 |
+
classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
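# Illustrative check (not in the original script): the split above is per individual,
# so no patient should appear in more than one of the train/validation/test sets.
assert set(train_indiv).isdisjoint(valid_indiv)
assert set(train_indiv).isdisjoint(test_indiv)
assert set(valid_indiv).isdisjoint(test_indiv)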
|
118 |
+
|
119 |
+
# define output directory path
|
120 |
+
current_date = datetime.datetime.now()
|
121 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
122 |
+
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
|
123 |
+
|
124 |
+
# ensure not overwriting previously saved model
|
125 |
+
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
|
126 |
+
if os.path.isfile(saved_model_test) == True:
|
127 |
+
raise Exception("Model already saved to this directory.")
|
128 |
+
|
129 |
+
# make output directory
|
130 |
+
subprocess.call(f'mkdir {output_dir}', shell=True)
|
131 |
+
|
132 |
+
# set training parameters
|
133 |
+
# how many pretrained layers to freeze
|
134 |
+
freeze_layers = 2
|
135 |
+
# batch size for training and eval
|
136 |
+
geneformer_batch_size = 12
|
137 |
+
# number of epochs
|
138 |
+
epochs = 1
|
139 |
+
# logging steps
|
140 |
+
logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
|
141 |
+
|
142 |
+
# define function to initiate model
|
143 |
+
def model_init():
|
144 |
+
model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
|
145 |
+
num_labels=len(target_names),
|
146 |
+
output_attentions = False,
|
147 |
+
output_hidden_states = False)
|
148 |
+
if freeze_layers is not None:
|
149 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
150 |
+
for module in modules_to_freeze:
|
151 |
+
for param in module.parameters():
|
152 |
+
param.requires_grad = False
|
153 |
+
|
154 |
+
model = model.to("cuda:0")
|
155 |
+
return model
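# Illustrative check (not in the original script): with freeze_layers = 2, the first two
# encoder layers should contribute no trainable parameters, e.g.:
#     m = model_init()
#     print(sum(p.numel() for p in m.parameters() if p.requires_grad),
#           "trainable of", sum(p.numel() for p in m.parameters()), "total parameters")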
|
156 |
+
|
157 |
+
# define metrics
|
158 |
+
# note: macro f1 score recommended for imbalanced multiclass classifiers
|
159 |
+
def compute_metrics(pred):
|
160 |
+
labels = pred.label_ids
|
161 |
+
preds = pred.predictions.argmax(-1)
|
162 |
+
# calculate accuracy using sklearn's function
|
163 |
+
acc = accuracy_score(labels, preds)
|
164 |
+
return {
|
165 |
+
'accuracy': acc,
|
166 |
+
}
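# The note above recommends macro F1 for imbalanced multiclass problems, but only accuracy
# is returned. A hedged variant (not in the original script) could add it:
#     from sklearn.metrics import f1_score
#     return {"accuracy": acc, "macro_f1": f1_score(labels, preds, average="macro")}
# If used, the search metric below would change from "eval_accuracy" to "eval_macro_f1".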
|
167 |
+
|
168 |
+
# set training arguments
|
169 |
+
training_args = {
|
170 |
+
"do_train": True,
|
171 |
+
"do_eval": True,
|
172 |
+
"evaluation_strategy": "steps",
|
173 |
+
"eval_steps": logging_steps,
|
174 |
+
"logging_steps": logging_steps,
|
175 |
+
"group_by_length": True,
|
176 |
+
"length_column_name": "length",
|
177 |
+
"disable_tqdm": True,
|
178 |
+
"skip_memory_metrics": True, # memory tracker causes errors in raytune
|
179 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
180 |
+
"per_device_eval_batch_size": geneformer_batch_size,
|
181 |
+
"num_train_epochs": epochs,
|
182 |
+
"load_best_model_at_end": True,
|
183 |
+
"output_dir": output_dir,
|
184 |
+
}
|
185 |
+
|
186 |
+
training_args_init = TrainingArguments(**training_args)
|
187 |
+
|
188 |
+
# create the trainer
|
189 |
+
trainer = Trainer(
|
190 |
+
model_init=model_init,
|
191 |
+
args=training_args_init,
|
192 |
+
data_collator=DataCollatorForCellClassification(),
|
193 |
+
train_dataset=classifier_trainset,
|
194 |
+
eval_dataset=classifier_validset,
|
195 |
+
compute_metrics=compute_metrics,
|
196 |
+
)
|
197 |
+
|
198 |
+
# specify raytune hyperparameter search space
|
199 |
+
ray_config = {
|
200 |
+
"num_train_epochs": tune.choice([epochs]),
|
201 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
202 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
203 |
+
"lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
|
204 |
+
"warmup_steps": tune.uniform(100, 2000),
|
205 |
+
"seed": tune.uniform(0,100),
|
206 |
+
"per_device_train_batch_size": tune.choice([geneformer_batch_size])
|
207 |
+
}
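# Note (not in the original script): tune.uniform samples floats, so "warmup_steps" and
# "seed" reach TrainingArguments as floats. Integer-valued alternatives exist in Ray Tune
# if that matters, e.g.:
#     "warmup_steps": tune.randint(100, 2000),
#     "seed": tune.randint(0, 100),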
|
208 |
+
|
209 |
+
hyperopt_search = HyperOptSearch(
|
210 |
+
metric="eval_accuracy", mode="max")
|
211 |
+
|
212 |
+
# optimize hyperparameters
|
213 |
+
trainer.hyperparameter_search(
|
214 |
+
direction="maximize",
|
215 |
+
backend="ray",
|
216 |
+
resources_per_trial={"cpu":8,"gpu":1},
|
217 |
+
hp_space=lambda _: ray_config,
|
218 |
+
search_alg=hyperopt_search,
|
219 |
+
n_trials=100, # number of trials
|
220 |
+
progress_reporter=tune.CLIReporter(max_report_frequency=600,
|
221 |
+
sort_by_metric=True,
|
222 |
+
max_progress_rows=100,
|
223 |
+
mode="max",
|
224 |
+
metric="eval_accuracy",
|
225 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"])
|
226 |
+
)
|
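The search call above discards its return value. A hedged sketch (not part of the committed script) that captures the result: Trainer.hyperparameter_search returns a transformers BestRun identifying the winning trial (the CLIReporter argument is omitted here for brevity).

best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,
)
print(best_run.run_id, best_run.objective)  # best trial id and its eval_accuracy
print(best_run.hyperparameters)             # e.g. learning_rate, weight_decay, warmup_steps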
examples/in_silico_perturbation.ipynb
ADDED
@@ -0,0 +1,110 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "e10ac0c9-40ce-41fb-b6fa-3d62b76f2e57",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from geneformer import InSilicoPerturber\n",
|
11 |
+
"from geneformer import InSilicoPerturberStats"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": null,
|
17 |
+
"id": "67b44366-f255-4415-a865-6a27a8ffcce7",
|
18 |
+
"metadata": {
|
19 |
+
"tags": []
|
20 |
+
},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"# in silico perturbation in deletion mode to determine genes whose \n",
|
24 |
+
"# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
|
25 |
+
"# the embedding towards non-failing (nf) state\n",
|
26 |
+
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
27 |
+
" perturb_rank_shift=None,\n",
|
28 |
+
" genes_to_perturb=\"all\",\n",
|
29 |
+
" combos=0,\n",
|
30 |
+
" anchor_gene=None,\n",
|
31 |
+
" model_type=\"CellClassifier\",\n",
|
32 |
+
" num_classes=3,\n",
|
33 |
+
" emb_mode=\"cell\",\n",
|
34 |
+
" cell_emb_style=\"mean_pool\",\n",
|
35 |
+
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
36 |
+
" cell_states_to_model={'state_key': 'disease', \n",
|
37 |
+
" 'start_state': 'dcm', \n",
|
38 |
+
" 'goal_state': 'nf', \n",
|
39 |
+
" 'alt_states': ['hcm']},\n",
|
40 |
+
" max_ncells=2000,\n",
|
41 |
+
" emb_layer=0,\n",
|
42 |
+
" forward_batch_size=400,\n",
|
43 |
+
" nproc=16)"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
+
"id": "0525a663-871a-4ce0-a135-cc203817ffa9",
|
50 |
+
"metadata": {},
|
51 |
+
"outputs": [],
|
52 |
+
"source": [
|
53 |
+
"# outputs intermediate files from in silico perturbation\n",
|
54 |
+
"isp.perturb_data(\"path/to/model\",\n",
|
55 |
+
" \"path/to/input_data\",\n",
|
56 |
+
" \"path/to/output_directory\",\n",
|
57 |
+
" \"output_prefix\")"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": null,
|
63 |
+
"id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [],
|
66 |
+
"source": [
|
67 |
+
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
68 |
+
" genes_perturbed=\"all\",\n",
|
69 |
+
" combos=0,\n",
|
70 |
+
" anchor_gene=None,\n",
|
71 |
+
" cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": null,
|
77 |
+
"id": "ffecfae6-e737-43e3-99e9-fa37ff46610b",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
+
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
82 |
+
"ispstats.get_stats(\"path/to/input_data\",\n",
|
83 |
+
" None,\n",
|
84 |
+
" \"path/to/output_directory\",\n",
|
85 |
+
" \"output_prefix\")"
|
86 |
+
]
|
87 |
+
}
|
88 |
+
],
|
89 |
+
"metadata": {
|
90 |
+
"kernelspec": {
|
91 |
+
"display_name": "Python 3 (ipykernel)",
|
92 |
+
"language": "python",
|
93 |
+
"name": "python3"
|
94 |
+
},
|
95 |
+
"language_info": {
|
96 |
+
"codemirror_mode": {
|
97 |
+
"name": "ipython",
|
98 |
+
"version": 3
|
99 |
+
},
|
100 |
+
"file_extension": ".py",
|
101 |
+
"mimetype": "text/x-python",
|
102 |
+
"name": "python",
|
103 |
+
"nbconvert_exporter": "python",
|
104 |
+
"pygments_lexer": "ipython3",
|
105 |
+
"version": "3.10.11"
|
106 |
+
}
|
107 |
+
},
|
108 |
+
"nbformat": 4,
|
109 |
+
"nbformat_minor": 5
|
110 |
+
}
|
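A minimal sketch (not part of the committed notebook) of inspecting the final .csv written by ispstats.get_stats above. The path and prefix mirror the placeholder arguments of that call; the exact column names depend on the InSilicoPerturberStats version, so they are printed rather than assumed:

import pandas as pd

stats_df = pd.read_csv("path/to/output_directory/output_prefix.csv")  # placeholder path
print(stats_df.columns.tolist())  # available statistics columns
print(stats_df.head())            # e.g. genes ranked by shift toward the goal state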
examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb
ADDED
@@ -0,0 +1,365 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "charged-worcester",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Obtain non-zero median expression value of each gene across Genecorpus-30M"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "28e87f2a-a33e-4fe3-81af-ad4cd62fcc1b",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"#### Upon request, we are providing the code that we used for obtaining the non-zero median expression value of each gene across the broad range of cell types represented in Genecorpus-30M that we use as a normalization factor to prioritize genes that uniquely distinguish cell state.\n",
|
17 |
+
"\n",
|
18 |
+
"#### Please read the important information below before using this code.\n",
|
19 |
+
"\n",
|
20 |
+
"#### If using Geneformer, to ensure consistency of the normalization factor used for each gene for all future datasets, <ins>**users should use the Geneformer transcriptome tokenizer to tokenize their datasets and should not re-calculate this normalization factor for their individual dataset** </ins>. This code for re-calculating the normalization factor should only be used by users who are pretraining a new model from scratch with a new pretraining corpus other than Genecorpus-30M.\n",
|
21 |
+
"\n",
|
22 |
+
"#### It is critical that this calculation is performed on a large-scale pretraining corpus that has tens of millions of cells from a broad range of human tissues. <ins>**The richness of variable cell states in the pretraining corpus is what allows this normalization factor to accomplish the goal of prioritizing genes that uniquely distinguish cell states.** </ins> This normalization factor for each gene is calculated once from the large-scale pretraining corpus and is used for all future datasets presented to the model. \n",
|
23 |
+
"\n",
|
24 |
+
"#### Of note, as discussed in the Methods, we only included droplet-based sequencing platforms in the pretraining corpus to assure expression value unit comparability for the calculation of this normalization factor. Users wishing to pretrain a new model from scratch with a new pretraining corpus should choose either droplet-based or plate-based platforms for calculating this normalization factor, or they should exercise caution that including both platforms may cause unintended effects on the results. Once the normalization factor is calculated however, data from any platform can be used with the model because the expression value units will be consistent within each individual cell.\n",
|
25 |
+
"\n",
|
26 |
+
"#### Please see the Methods in the manuscript for a description of the procedure enacted by this code, an excerpt of which is below for convenience:\n",
|
27 |
+
"\n",
|
28 |
+
"#### \"To accomplish this, we first calculated the non-zero median value of expression of each detected gene across all cells passing quality filtering from the entire Genecorpus-30M. We aggregated the transcript count distribution for each gene in a memory-efficient manner by scanning through chunks of .loom data using loompy, normalizing the gene transcript counts in each cell by the total transcript count of that cell to account for varying sequencing depth and updating the normalized count distribution of the gene within the t-digest data structure developed for accurate online accumulation of rank-based statistics. We then normalized the genes in each single-cell transcriptome by the non-zero median value of expression of that gene across Genecorpus-30M and ordered the genes by the rank of their normalized expression in that specific cell. Of note, we opted to use the non-zero median value of expression rather than include zeros in the distribution so as not to weight the value by tissue representation within Genecorpus-30M, assuming that a representative range of transcript values would be observed within the cells in which each gene was detected. This normalization factor for each gene is calculated once from the pretraining corpus and is used for all future datasets presented to the model. The provided tokenizer code includes this normalization procedure and should be used for tokenizing new datasets presented to Geneformer to ensure consistency of the normalization factor used for each gene.\""
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 1,
|
34 |
+
"id": "textile-destruction",
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"import os\n",
|
39 |
+
"import numpy as np\n",
|
40 |
+
"import loompy as lp\n",
|
41 |
+
"import pandas as pd\n",
|
42 |
+
"import crick\n",
|
43 |
+
"import pickle\n",
|
44 |
+
"import math\n",
|
45 |
+
"from tqdm.notebook import tqdm"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"id": "4af8cfef-05f2-47e0-b8d2-71ca025059c7",
|
51 |
+
"metadata": {
|
52 |
+
"tags": []
|
53 |
+
},
|
54 |
+
"source": [
|
55 |
+
"### The following code is an example of how the nonzero median expression values are obtained for a single input file. This calculation should be run as a script to be parallelized for all dataset files."
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 30,
|
61 |
+
"id": "physical-intro",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"input_file = \"study1.loom\"\n",
|
66 |
+
"current_database = \"database1\"\n",
|
67 |
+
"\n",
|
68 |
+
"rootdir = f\"/path/to/{current_database}/data/\"\n",
|
69 |
+
"output_file = input_file.replace(\".loom\", \".gene_median_digest_dict.pickle\")\n",
|
70 |
+
"outdir = rootdir.replace(\"/data/\", \"/tdigest/\")\n",
|
71 |
+
"\n",
|
72 |
+
"with lp.connect(f\"{rootdir}{input_file}\") as data:\n",
|
73 |
+
" # define coordinates of protein-coding or miRNA genes\n",
|
74 |
+
" coding_miRNA_loc = np.where((data.ra.gene_type == \"protein_coding\") | (data.ra.gene_type == \"miRNA\"))[0]\n",
|
75 |
+
" coding_miRNA_genes = data.ra[\"ensembl_id\"][coding_miRNA_loc]\n",
|
76 |
+
" \n",
|
77 |
+
" # initiate tdigests\n",
|
78 |
+
" median_digests = [crick.tdigest.TDigest() for _ in range(len(coding_miRNA_loc))]\n",
|
79 |
+
" \n",
|
80 |
+
" # initiate progress meters\n",
|
81 |
+
" progress = tqdm(total=len(coding_miRNA_loc))\n",
|
82 |
+
" last_view_row = 0\n",
|
83 |
+
" progress.update(0)\n",
|
84 |
+
" \n",
|
85 |
+
" for (ix, selection, view) in data.scan(items=coding_miRNA_loc, axis=0):\n",
|
86 |
+
" # define coordinates of cells passing filter\n",
|
87 |
+
" filter_passed_loc = np.where(view.ca.filter_pass == 1)[0]\n",
|
88 |
+
" subview = view.view[:, filter_passed_loc]\n",
|
89 |
+
" # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision\n",
|
90 |
+
" subview_norm_array = subview[:,:]/subview.ca.n_counts*10_000\n",
|
91 |
+
" # if integer, convert to float to prevent error with filling with nan\n",
|
92 |
+
" if np.issubdtype(subview_norm_array.dtype, np.integer):\n",
|
93 |
+
" subview_norm_array = subview_norm_array.astype(np.float32)\n",
|
94 |
+
" # mask zeroes from distribution tdigest by filling with nan\n",
|
95 |
+
" nonzero_data = np.ma.masked_equal(subview_norm_array, 0.0).filled(np.nan)\n",
|
96 |
+
" # update tdigests\n",
|
97 |
+
" [median_digests[i+last_view_row].update(nonzero_data[i,:]) for i in range(nonzero_data.shape[0])]\n",
|
98 |
+
" # update progress meters\n",
|
99 |
+
" progress.update(view.shape[0])\n",
|
100 |
+
" last_view_row = last_view_row + view.shape[0]\n",
|
101 |
+
" \n",
|
102 |
+
"median_digest_dict = dict(zip(coding_miRNA_genes, median_digests))\n",
|
103 |
+
"with open(f\"{outdir}{output_file}\", \"wb\") as fp:\n",
|
104 |
+
" pickle.dump(median_digest_dict, fp)"
|
105 |
+
]
|
106 |
+
},
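A toy sketch (not from the notebook) of what each entry of median_digest_dict accumulates: a crick t-digest is updated with a gene's nonzero normalized counts and its running median is read back with quantile(0.5). The numbers below are made up for illustration:

import numpy as np
import crick

toy_counts = np.array([3.0, 0.0, 7.0, 2.0, 0.0, 5.0])  # normalized counts of one gene across cells
digest = crick.tdigest.TDigest()
digest.update(toy_counts[toy_counts > 0])               # zeros are excluded, as in the cell above
print(digest.quantile(0.5))                             # non-zero median of [3, 7, 2, 5]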
|
107 |
+
{
|
108 |
+
"cell_type": "markdown",
|
109 |
+
"id": "190a3754-aafa-4ccf-ba97-951c94ea3030",
|
110 |
+
"metadata": {
|
111 |
+
"tags": []
|
112 |
+
},
|
113 |
+
"source": [
|
114 |
+
"### After the above code is run as a script in parallel for all datasets to obtain the nonzero median tdigests for their contained genes, the following code can be run to merge the tdigests across all datasets."
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 2,
|
120 |
+
"id": "distributed-riding",
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"# merge new tdigests into total tdigest dict\n",
|
125 |
+
"def merge_digest(dict_key_ensembl_id, dict_value_tdigest, new_tdigest_dict):\n",
|
126 |
+
" new_gene_tdigest = new_tdigest_dict.get(dict_key_ensembl_id)\n",
|
127 |
+
" if new_gene_tdigest is not None:\n",
|
128 |
+
" dict_value_tdigest.merge(new_gene_tdigest)\n",
|
129 |
+
" return dict_value_tdigest\n",
|
130 |
+
" elif new_gene_tdigest is None:\n",
|
131 |
+
" return dict_value_tdigest"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"id": "distinct-library",
|
138 |
+
"metadata": {},
|
139 |
+
"outputs": [],
|
140 |
+
"source": [
|
141 |
+
"# use tdigest1.merge(tdigest2) to merge tdigest1, tdigest2, ...tdigestn\n",
|
142 |
+
"# then, extract median by tdigest1.quantile(0.5)\n",
|
143 |
+
"\n",
|
144 |
+
"databases = [\"database1\", \"database2\", \"...databaseN\"]\n",
|
145 |
+
"\n",
|
146 |
+
"# obtain gene list\n",
|
147 |
+
"gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
|
148 |
+
"func_gene_list = [i for i in gene_info[(gene_info[\"gene_type\"] == \"protein_coding\") | (gene_info[\"gene_type\"] == \"miRNA\")][\"ensembl_id\"]]\n",
|
149 |
+
"\n",
|
150 |
+
"# initiate tdigests\n",
|
151 |
+
"median_digests = [crick.tdigest.TDigest() for _ in range(len(func_gene_list))]\n",
|
152 |
+
"total_tdigest_dict = dict(zip(func_gene_list, median_digests))\n",
|
153 |
+
"\n",
|
154 |
+
"# merge tdigests\n",
|
155 |
+
"for current_database in databases:\n",
|
156 |
+
" rootdir = f\"/path/to/{current_database}/tdigest/\"\n",
|
157 |
+
" \n",
|
158 |
+
" for subdir, dirs, files in os.walk(rootdir):\t\n",
|
159 |
+
" for file in files:\n",
|
160 |
+
" if file.endswith(\".gene_median_digest_dict.pickle\"):\n",
|
161 |
+
" with open(f\"{rootdir}{file}\", \"rb\") as fp:\n",
|
162 |
+
" tdigest_dict = pickle.load(fp)\n",
|
163 |
+
" total_tdigest_dict = {k: merge_digest(k,v,tdigest_dict) for k, v in total_tdigest_dict.items()}\n",
|
164 |
+
"\n",
|
165 |
+
"# save dict of merged tdigests\n",
|
166 |
+
"with open(f\"/path/to/total_gene_tdigest_dict.pickle\", \"wb\") as fp:\n",
|
167 |
+
" pickle.dump(total_tdigest_dict, fp)\n",
|
168 |
+
"\n",
|
169 |
+
"# extract medians and save dict\n",
|
170 |
+
"total_median_dict = {k: v.quantile(0.5) for k, v in total_tdigest_dict.items()}\n",
|
171 |
+
"with open(f\"/path/to/total_gene_median_dict.pickle\", \"wb\") as fp:\n",
|
172 |
+
" pickle.dump(total_median_dict, fp)\n",
|
173 |
+
"\n",
|
174 |
+
"# save dict of only detected genes' medians \n",
|
175 |
+
"detected_median_dict = {k: v for k, v in total_median_dict.items() if not math.isnan(v)}\n",
|
176 |
+
"with open(f\"/path/to/detected_gene_median_dict.pickle\", \"wb\") as fp:\n",
|
177 |
+
" pickle.dump(detected_median_dict, fp)"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "markdown",
|
182 |
+
"id": "e8e17ad6-79ac-4f34-aa0c-1eaa1bace2e5",
|
183 |
+
"metadata": {
|
184 |
+
"tags": []
|
185 |
+
},
|
186 |
+
"source": [
|
187 |
+
"### The below code displays some characteristics of the genes detected in the pretraining corpus."
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"cell_type": "code",
|
192 |
+
"execution_count": 38,
|
193 |
+
"id": "decent-switzerland",
|
194 |
+
"metadata": {},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"gene_detection_counts_dict = {k: v.size() for k, v in total_tdigest_dict.items()}"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 44,
|
203 |
+
"id": "polished-innocent",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"name": "stderr",
|
208 |
+
"output_type": "stream",
|
209 |
+
"text": [
|
210 |
+
"/home1/ct68/miniconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
|
211 |
+
" warnings.warn(msg, FutureWarning)\n"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"data": {
|
216 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAMRCAYAAABlG8GWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABcSAAAXEgFnn9JSAAC/KUlEQVR4nOzdd5hjZ3X48e/Z7l2vK240G0wzBgOmmmp6NT/TQmjBlCS0ACGE3gklJARCLyGYGgi9hRqwgYBpxnRMMTZgsI1x2+Lt5/fHe8d7dUfSSBpdaWb2+3kePaN7dcs7M1ea0dF5z4nMRJIkSZIkSZLGbdm0ByBJkiRJkiRpaTL4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWrFi2gOQJEmTERFXAY4GDgf2A1YDG4BLgPOA0zPzT9Ma3yAi4mTgkbVVj8rMk/tsfwTwm9qqczLziDbGJqm3iDib8toz4xqZefZ0RqM9XUQcD3yltsq/DZLUIoOPkiQtYRFxLHAScF863/j32v4c4AvAe4GvZWa2OkANLCJOAe4wzmNmZozzeJImw9cDSdJi4rRrSZL6iIjlEXFpRGR1e/uA271m0mNtjOeGEfFF4HvA3zFA4LFyOPDXwKnAryLiERHh/wuSJEmSRuKbCUmS+rsJsE9t+Ss9trtZY7tT2hpQP1E8AzgduEufTRO4mDLtuld24zWBdwPfGOsgJUmSJO0xnHYtSVJ/zWltp/TY7k61+7uAr7Yymj6qDMX/AB7V5eE/AB8DPgt8F7gwM3dW+60GrgPcFrgf5XtZXtv3ei0OW6M7DXjXtAchaUHw9UCStGAZfJQkqb/ja/fPzMw/9NiuHnz8QWZe3N6QenoDswOPG4FXAK/JzMu77ZSZW4EfVbc3R8SRwHMptSKtAbZwnZmZb5n2IKRB2Myjdb4eSJIWLKddS5LUQ5VJeNvaqq5TriNiFXDrubZrU0Q8Gnh8Y/X5wO0y8+W9Ao/dZOavM/PRlO/prDEOU5IkSdIexuCjJEm93RjYr7bcK6h4K2BtbfmUdobTXURcGWg2uLkEuG1mnjHqcTPzNErNyy+MPDhJkiRJezSDj5Ik9bZY6j0+n85mNwBPycxfzffAmXkZ8JfzPY4kSZKkPZM1HyVJ6u342v2fZOYFPbarBx+/n5mXtjekThFxILPrPH41M989rnNk5q5R9ouI/SlZoYcAB1G6av8J+A1wWmZuG9cY2xQRhwI3Ao4A9gVWAZcDlwG/BX49jkCveouIFcAtgBsABwJbKE2UvtfWzz4i1lGu3+sA+1M+WPhjZg7U1CMiDgeOpVz7BwKbgQuAnwI/zMxeXeYHHV8ARwLHAIdRPoCI6jwXAedQ6gCeN+Lx96mOfR1KBvhaYCuwCTgXOBv4aWZun8/3MR8RsS9wG+DawN7ApZTr4quZeeGYzrEGuD1wdeBgys/gt8C3MvO34zjHYlS7/q5L+dnsQ0lsuRi4kPLc/E1L574ecDRwJeAAYAfld/8r4EeZ+acxnWcl5TXgBpTXgMspz+Fv+ZovScMx+ChJUhdVvcfb1Vad0mO7vYBbzrVdix4GrG6se9OEx3CFiFgOPBL4a+DmdHbNrtsYEZ8BXpyZP5vU+AZVBbseW91uOsD2FwFfAz4MvH/UgO1CEBF3BT5H5wyZl2TmC4c4xq2BU+n8X/M1mfm0Hts3A3HXyMyzI2It8EzgiZQAXrd9Twf+KTM/NsT4TgLeWVt1amYeXz12beBFwAOY/dyCPh2FI2I/4O8p2cLX6TOE8yLivcArMvOiQcddnWN/4OnAwylBn7m2/y3wReA9mXnqANvfHXgycDfmfq+wJSK+A3wMOLlfo62IOBs4vLbqGpl5dp/tXwTUr7l3ZeZJ1WPXBP4JeCCwssvuGRFfAp6VmafP8T30Ov9VgJdRroO9e2zzDcpr2Bd6jPnFmfmiUc6/EFXX3v2Be1MCsl2fk7XtzwXeAbx+vsHgqhHaPwInAFfus2lGxA+BjwLvyMxzRzjXeuA5wOPoLL1S3+ZnwAsy88PDHl+S9kROu5Yk7bEiInvdgJ2UTIcZT+yx3WZKJtyMf+hz3ONb+Dbu21i+iBIImLiIuC0lq+sdlGyRXoFHKG/mHwz8KCJeXmXRLAhVxtrpwJsZIPBYOQD4f8B7mD0FflHJzC8CL2msfl4VlJpTRFwJ+CCdgatvUoKIA4uIawDfA15A/yDHscBHI+JDVZbayKrGTT8CHkr3wGO/ff+W0qDpBfQPPAIcSgkgnhURDxziHHcDfkkJjMwZeKxcHXgM8O9zHHtNRHyQEni+F4MlKayhfEjzb8wuU9GKiHgI8GPgIXQPPELJAL0r8K2I+KsRz/FzygcpXQOPlVsDn4+If68+sFqyIuIY4DzgP4D7MUfgsXIVyvPhVxFxnxHPuyYi3kT5ffwt/QOPUH73NwJeDHx6hPPdiPIa8Cx6BB4rRwEfiog3L/XfvSSNgy+UkiQtUlV23m0aq78xjenMEfEI4H/pHnRJyhTljV0eWw48G/ivhfAGLiIOoGQw3rDHJpspUwo3TWxQ0/FS4PO15WXAeyPiqv12qn6H7wPq210I/MWQ03OvRLmertdYv5Ey9bGbBwIfj4ihgoYzIuKRlMB5c/9LgJ5jj4jlEfF64C10fmAxYydlKurWLo/tC/x3RPzdAOO7DfApugd9EthA+Vl3O88gPgz8RY/HtgJ/pjyPp5bVGxEPp1xfe9VW76L8fLu97q0ATo6IOw5xjkcC76V70HHmXDsb65/M7KZfS81aOj9oq9tOuT66vcZDuc4/GREPHuaEVTO1rwKPp3cw/DLKtd/1EEOe7waUxnKHNx66jN6v+Y+jBFglSX1M/Z98SZI0sqMomUd135n0IKqMlnfR+cb0YuBVwHHAmszcNzPXUzJJHgB8o3GYB1Ma50zby4Cr1ZYTeDdlCur+mbkuMw/KzL0p3+/RlAyskyn1LJeEatr4w4Hf1VZfiRIo65VtBuV3eLf6oYCHZ+bvhxzCG4BrVPd/TalreqXMXJ+ZaykZVU+m1F+ruzvwyiHPBSU7cKZcwS7KlOw7AKszc39KQPIISjZU0yuBJzXW/Yoy/fr6wMrMPCAz11TneQKlHuOMAF4bEXeiv7fR+Ry7DHg5pbzBuszcp7o211ACZzenBG0+Se+AbRlACQrdu7H6a5Tp41fJzDWZeaXM3JcSBDoCuA/wauDMOcY9LjegZN0Fpebnayh1QFdVP9/V1TZvoDNAGsB/VCUh+qqy3t5O53ukXcBbKdncqzPzAMrv4UaU17iZYO+TgXuM/N0tHpcDn6Fc87cB9svMVdX1sR5YR8kIfTWdwciZ38ORg5ykKmnyGcp1XHcxJTP7lpTfx76ZuQ/ld3ITyvPrS8wOEM9lL8qsgZkPE
D5G+X2uq86xN+V15x8oH0jUPSci5sp2lqQ9Wsyz1rUkSYtWRDyux0PLgNexe9rwl4EPddluNfDa2vIngc/2OeUnM/MPQw6zp4i4H6WuVd39MvPj4zrHAGM4HPg+nRlfXwQekZnnz7Hv8ygZdjN2ATfLzO/32edkylTIGY/KzJP7bH8EpcHNjHMy84ge266iBLP2ra3+y8z8YK/jd9n/AcDHM7NvsGcUEXEKnVNbr6iB15aIuBUl86gecPz3zHxql23vQsmWrAduXpqZc2YFdan5OONTlN/B5h77HVidsz49fhdw28z8Zp/znURnzccZG4D7ZuYpc425Os59gU80Vv8r8Nx+GcgRsTclg69eNuEPwJGZuaXL9rcAvlVbdQlwq8wcKPBXZfTeOTO7vY4REf8D3LO26s3AEwdtihMRtwf+1K9+6xhqPs74DXCvzPx5n30fQfnQoO7/ZeYn++yzjDLN/8a11RuBe2fmV/vsdzQl2HVol4dbq/k46deDiLguJUD9jkGbqlV1Mz9JKY0w4z8z8zED7Hsyna/1AB+nvOZfMsD+R1Cey6/r8fjxlCzHps2UD0x6li+pMiT/j84SGz1r2kqSgMz05s2bN2/evNVuwM0oGVszt4f32O4Oje3uOeFxPqFx/gRuN+Ex/Gfj/F+lZKMMuv8bG/t/YI7tT25sf9Ic2x/R2P7sPtter7Htt6d9LTbGd0qX3/d8bicOeN6ndNn3AY1trkIJ3Na3+RKwbMBzdBvfDylZs3PteyXg/Ma+n5ljn5N6nPM+Q/w+llHqL9b3f9UQ+6+hBLvq+z+ux7aPG/U8A47lvNqxtwH7tnD9nt34Ho6YY/sXdfn9XAZce8DzfbKx77vn2P5eo14PlIDl9i77v2jcP8faOafyejDCOA+mTMmeOc8WShZ5v32OoXyIUB/fhwd9PRlwXMf3+Dn85YD7P62x32/b+l178+bN21K4Oe1akqTZmtMfv9xju+Nr93cCX29lNL2t77JuoIyUcajq/z28tmoH8NjMHKbm3HMpAYUZD6yy2abhgMbyr6YyigUmM/+d8sa/7j+jdIWeqT36AeCg2uN/AB6a8+v6/eTskgXYZXwXUq6juntUDWuG8enMHKZBxQOBa9WWfwU8b9Cdq+/tHxure2Vjt31t1o9/YQ6Y2TYFr8zMXw647dsay83pu03Nn/0nBr0eMvMMdk/bV01mXkCppTpjNWVadj/PprNe4x8pf1varjf6xcz8wIDbvpPyN2/G1SLikBbGJElLgsFHSZJmqwcfz8zeU6WPr93/Xmb2Knrflm6NNSbZCOVBdE7H/Vxm/mKYA2SZPve52qrllO6503BJY/kmC6EJzgLxaKD+u90H+HBVl+2VwG1rj+2gZA816zEO46c54NTnynvpDGIvY3YNw7k0g1VzeVhj+S05ZLOnzPwyJetwxjHVFOmmSxrLg3ZhH1T9+IfM1VhoSnYx3O+oWVf2Or2ez1UA/c6N1cMGE9885PZ7ktMay7fqtWFVU/a+jdWvywGmWo/BwL/DzLwYaJYZaDbIkiRVenUNkyRpj1S98akHUrpmPVYddetvoE5pcVi9dMswXDfB89+hsfy5rlvN7Xt0dtk9jlLba9LOpDQzmKlfeT3grRHxtCkElgdxGqXRz6jOGHTDzNwQEQ+k1B2c6TR8DKUj9XGNzZ+bmV+bx7hgdh3Fuca3JSK+QMlGnHErSvORgQ4BnDro+aogVjNIPur1/31211sMSiONZu3YZvDmMRFxBvDWMWWDnQacUN1fRgks/2X2qck4BT+uslwHkpkXRcSl7K7huoySLd4tq/MYSjfnGVvonfHe63w/j4izgGsOs98YTez1oK7qSH09SjOx9ZRyAs0u081mLFejt1vS+buA8uHCJAz8GlA5C7hhbXm/8Q1FkpYWg4+SJHW6BZ0BvK/02O5WdHaaPqWtAfWxscu6fbusa0sze+V6fZr49HNMY/mwEcczL5m5MyLeSmdH48cCD4qIj1A6r34tMxdKV+szM/MtkzpZZv4oIh5Pqbs5oxl4/BTwL2M43ekj7lMPPt5oiH3PyczL5t7sCtehs8kSwJ0iYpSs3Ss1lmdd/5l5ekScxu7n3HJKZt4zIuK/KYHPb2WPxjwDeCO7g49QAkC/jIjPUgLBp2Tmr0c89ricPcI+G+h8TdyH7sHHZsbajzNzR5ft5vJ9phd8nNjrQUTcjZL5e29glDIZzedOXTOr95zM/P0I5xjWZZl50ZD7ND+U2qfrVpIkg4+SJDXUp1wnvYOPx9fuT6PeI5Q6WE3dpmyOXZX5dVBj9ZPGdPhp1XwEeAlwezprku1LmXb8aICI+AWl0+nXgC9n5jmTHuS0ZOa7qgBbt261ZwOPzMwcw6lG+Zk29xnmOvrzkOfq1tm4a1fdEfQa9yMoU4nrz7sjgGdUtx0R8X3KtflVSsDw4kFOmJmfj4jXAH9fW72CEpA8ASAizqvO/zXg1OzTlb4ll4ywz87G8vIe2zWDYb1Kbcyl22vykhER16JMfb/jPA/VrV7xjObflUnV3r1khH0Gvb4kaY9n8FGStEeJiJtRuln3Us+cuojSAKXbdvev3f8T8LAe2/0hMz857DgH1C0T6RiGnLI6ov1pr3Z0c8rdxGTm5RFxZ+AVlG7iq7psdp3q9iiAiPg28HbgXZm5fVJjnaLnU7pFN99oP2bQYNcAhslCnNHMaOuXXdXULYu4nzYD5F2v/8z8VUTclFKXrls9yxWUpio3B54KbK+mor82M78010kz82kR8XPgZczOxoQScL1/dSMizgbeTanHN2zwdhTjCGr3sl9jedQyC6Nct4tCRNyA0sF+HE1V+v3taD63LhnD+QbR5vUlSXs8g4+SpD3NfYAXDrjtgQxWgP7QPtudCrQVfPwppe5jvfHMXB1dx6VbUG5cukZxJ6XqQvz3EfFq4K+AE4Fj6Z3Vcovq9oyIeEhmfm8iA52CKBH2t9D9Z/F4hqyT18cogYBJXjdTuf4z83fAfSLiWMq1eS/g2j02X0kJUt47Ij5HyUrt2wQoM98WEe8HHkxpKHVbeteRPQJ4AfDUiHhiZk6qLl8bmvVzR/39tnldTE1VC/kDzA48/gD4KPAdSubxecDlwNZ6LdKIOJ7eswjmYlBQkpYAg4+SJC1Smbk9Ir5B5xS420TEqmG77o6gW6bTUZn585bPOzFVnbGXAy+PiPWUenvHAbehBGWaGWrXBr4cEbfNzB9NdLCT84/M7kQ744ER8ZTM/PcxnGeU2qXNemvjysLspnn9n5+Z3aZityIzT6fUuHxqRBxGKRNwa8p1eVNmB4fvAXwpIm6dmX2zPKvH3wG8owo63YRy3d+WUpLg4MYu+wDviYgVmXnyvL6x6WleK/uNeJxR91voHgYcXVveATxqiIDz3kOcq9lUaJgMZknSAtXWdClJkjQZzazKA4D7tX3SKrjZnGJ4rbbPOy2ZuSEzv5iZL8nMu1N+zicAn29sug+Dd1heVKpajy9rrD6rsfwvEXHLMZzu8DHs0+ZU4GbToUOqAPXEZeYfM/MjmfkPmXlLSnDw
r4GfNTa9ISV4PMyxt2fmtzPz3zPzQZQs71tS6v41G7K8NiIWa6CoWavxqBGPM+p+C939G8uvHDLTtVnHsZ/mc2vJ/l2RpD2JwUdJ0h4lM1+UmdHtBvxPbdMz+mz31dp2p/Xarrod3/K39D6gmeX4hJbPOaPZcOL4CZ136jJza2Z+OjPvQWd3bIDbR8TVpzGutkTEwZRpl/VZM6dSaozWO1OvBP47Iubb+OjYMezzg3mOoZ+fAVsa6+7Q4vkGlpkXZeZ/ULp9f6rx8CPmeeysgpF/S8m4rgcg96WzY/Zi8t3G8lUj4irDHCAiVgM3HtuIFpYbN5bfPeT+txhi2+bv4vCIuOqQ55MkLTAGHyVJAiJiOXC72qpTemy3hpL5M2PUOlZjkZl/At7VWH37iPircZ2j6mzdzRcbyw+IiD2xpMurmJ05daNpDKQN1e//fcCVa6vPBx6SmZsoTZouqT12deDd0aMD04D+35BjXAPcrbH6tHmcv6+qLmizw/2D2zrfKKrmR89orL7GuDI0M/PrwEcaqxfldV/Vwmx2VX7YkIe5H73rYy52zan2A3ejr/623muIc30H2NRY9/Ah9pckLUAGHyVJKm4O1N+U9woqHkdng5epBh8rL2Z2d9Z/j4h5T1eLiH2A/+rx8EeAXbXlI4DHzPeci01mJrPfjC+lIMQLgbvUlncBD83MPwJk5m+oOn/X3Bt45jzOef2IGCaT8OF01nzcBXxmHucfxH83lh8SEddv+ZzD+k2XdeO8NpvHX8zX/fsay0+NiIFqj1Yfujx3/ENaMJrZ9fsNse9DKR9IDKQKmn+8sfrvBv1dSJIWJoOPkiQV9aYtu+icWt1ru+3A/7U2ogFl5rnAPzRW7wd8PSJGzkSKiFtRptTevcd5f06Zilv3rxFxk3mcc2qdrkfN2qyacjQDvefNf0TTFxF3A57XWP2izOzoap2ZHwde3djunyLi9vM4/eurqaxzjfFKzK5F+fkqKNqmk4Gza8vLgQ9FxH6jHrDX9T+PjOJmMHQnjZp688xWbh5/MV/3b6O8ps84DPiPKnNvLv8C3KCVUS0Mv28sDzS9vio/8doRzvdKOrtcX5nSAMn3rpK0SPkCLklSUQ8qnpGZl/TY7vja/W9l5ubWRjSEzHw75c1z3SHA1yLi2RGx16DHiohrRsQ7KIHVI+fY/PnApbXlvYH/jYgHDHq+6pyHRcQLmV2jbpKeEBGfjYh7DPkm9xXAlWrLGylTBxe1qs7a++j8f/HzzA70zXgWncH45cAHqnqRo7ghJZjX89qNiAOBz9E5LTT7jHFsqgytpzdWX58S9B8qEBURx0TE2ylZzN28OyLeHhHHDHHMdcwO/HwtM3c21t0gIn4YEY+p9hn0+P8PuE9j9ULIBB9JZv6BEvSqeyDwyYi4Wrd9IuLAiDgZeGq1qlkHdKn4cmP5ZRHR929DlQX8VUpzrqFk5o+BdzZWPwD48BDZqEdExJOHPbckqR17Yl0mSZI6RMQq4Da1Vaf02G4vFlC9xy6eAOxFZ1OJ9cDLgSdFxEeBzwLfAy6cCUJU2WXXpvwM7g/cmRI4mlNmnhURD6ZMcZ3ZZ3/Km8TTgP+gvAH9dWbuqs4XlGDRDYGbUrJojqMEub430nc+HsuAe1S3P0XEJyi/4zOAX1UdvgGIiEOA2wNPqr7Wvb2qhdi260bE4+Z5jK9k5pnNlVU23AfpDKr+Hnj4zO+xKTN3VNfC99nd3fYw4P0Rcbde+/XwLcpz7QTgRxHxT8AnM/OianyHUQJDz2N2Pbo3ZOZEMpIz8yPV2OrZoUcDZ0TExyglC76RmVdkBFaZdIdTmvUcR6lved3q4Wb26Iy1wEOAx0bEmcDHgG9Srs3f155by4BrUK7hpwHXbBznNT2Of0PKc/X1EfF5SpD5dOAn9Wu5qhd5c+CvKK8z9cD09+idMb5YvBS4J3Cz2rp7Ab+KiC9RmqFcSHmNuxElK3wmYPt7SimKp9T2rWfvta211wPgLcDj2f37PgT4bkS8DPjvzPwtXPG6cTPKVOu/BVZV25/C8A3JnkRpInXj2rr7AXeIiNcDnwZ+UH0IMJOBfn3Kc+r+wJ2AHwOvG/K8kqQWGHyUJKkEOdbWlnsFFW/N7jdT/babiszcGRGPBH4OvITOAOKVKW/mnjSzeURcXG2znv6zIZpdrZvn/XwVdHonnXUzb1XdAHZFxKXVeeY630JwEPDY6gZARGwGNlOCDb2y8b7L5Gq/1X++o3oU0C3Y8M+U633GDuDBmXlhv4Nl5rkR8TBKNuLM7/jOlLqRLxxiXE+i1FS8BiX79p0AEbGBcs2u7bHfl5jdZKVtL6CM6VnAzLTp5ZTg6AMBImIHJUN4DfOvi3hdOjusZ/Vz2U6pe7myx35vzMxPznHsvYATqxsAEbENuIxS67ZXs5o/UwLTkwy2jV1mbo+IuwNfoHwwMmMVJQjZq3HKRZRA+f0a6yeZCdna60Fm/jgiXkNneY/9KNPN/yUiNgFbKUHZZumAzwP/ypDBx8y8PCLuTcmGr3eyP4Da60n1dyUo1+bUynZIkvpb6P/4S5I0CfUp1zuBrw2w3VZK5tGCksXLKW+c+wVHg/Imbl96/z/wE+ABmXnHHo/Xz/sR4Bb07jC8jPLGtN/5dgE/mOtcLZorcLKWkgnYK/D438AdM/PysY5qwiLiRErWXN2zMvMbg+yfmV+kZJDVPa+qHzmoCylBy2YgZD29A48fA+5bdaKemOo59xxKBuNve2y2AjiQ/oHHzXQPBEP/azMoQccD6R543Ao8PzOf1OWxuY4NJfB2JXoHHs8AjqtqwC56VXbtHSkZc80p6t18i/L9n0Fn0yPo7AK/2D0DeEePx9ZR/p40g38fomQh7hjlhNVU+NtRPnzolTm9L+Xn3i3wOEy2tSSpRQYfJUmaXe/x0h7bHV+7f9qkgxzDyMwfZOadKFMk3wj8bsBdzwbeRHkzfYPM/OgQ5/x5Zh5HCRp9nM5akL1spmTGPB04PDOn2S37DcBtKTUcv0kJ2sxlM+UN9h0y88GZubHF8bUuIq5JaaRS94nM7DUduJeXULIQZywD3hsRVxn0AFXDmGOrY/25z6Y/AB6YmfefZuA3Mz9FaTz0KODrzO4Q3M0FlOntJwGHVrVbu3kYJbj5VuCnDDad93zKNX39zPynPuP+AXAU8I+U5+IlAxx7F2Uq7SOBm2bmLwfYZ9HIzA2Z+RTKNN4XUj5U+QMlu3QTJbv8XZRMyOMy8xfVrs0SABdPZsTty8xdmflY4EHM/SHRtykfXP3FfOsiZ+bmzHw0pTTAe+j/WgDl2vwW5Xoe5gMPSVKLYpHPjpAkSQOqGojcALg6ZcrcKkpzlIuBPwLfy8y53tgNc77llHpd16JkZe1PCchspLyRP5NSC3J7r2NMU1UL9HqUab9XpmR+LaeM/8+UINBPMnOQIKX6iIjmP6TXyMyza4+voGTW3pByLW2hXEPfW6iBr1qN2KtSxryeEqy+DDiHcv3/bpSpylXTjetRajoeTMk8S2AD5bn8I0qd0qEzv6qarNeqblejZJatrsZ
+KfAL4Id9PqTZY0XELyk/txk3rJqnLDkRcS3K9X0oJRt8I+W6/nZmntvieZdR/q5ch5KRux9wOeXv2C+BH/VpGCdJmhKDj5IkSZqquYKP0kJXdTj/UW3VRmDfUQLAkiQtNU67liRJkqT5eX5j+csGHiVJKgw+SpIkSRIQEatH2OepwF80Vr9pLAOSJGkJMPgoSZIkScXLI+JjEXGPqu5rTxFxREScDLym8dC3gS+0NUBJkhabFdMegCRJkiQtEMuBE6vbhoj4FqWW4/mUOo57A4dRmq3cvNq+bgPw0FEaCUmStFQZfJQkSZKk2dYDd6lugzgPuH9m/rq9IUmStPg47VqSJEmSirOArUPusx14F3CzzPzm+IckSdLiFs4IkCRJ0jRFRPMf0mtk5tnTGIsUEeuBuwG3Bm4EHA4cBKwFErgE+DPwQ+BrwCcz83dTGawkSYuAwUdJkiRJkiRJrXDatSRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa1YMe0BSJKk3iLiFOAOtVV3zMxTpjOa6YqIVcAxwDWBw4B1wE7gEuBi4EzgJ5m5Y1pjnISIyPpyZsYc258MPLK26lGZefL4RyZNX0QcD3ylturUzDx+KoPRghAR1wKuC1wN2AdYBWxk99+OnwC/zszsdQxJ0vwYfJSkJaRLkKGXncBllH+8fwl8G/hcZv5fa4OTRhARa4GHAX8B3A5YPccul0fEd4H/Bj6YmX9qeYiSFrkh/3ZuAC4FzgFOB04FPpOZ20c894uAF3Z56NOZecKIx2wG0e6amV8a5Vi1Y34CuG9j9ecy857zOW4bImIZcB/gwcA9gAMG2G1DRJwOfBL4SGae0+IQJWmP47RrSdozLQf2B64B3A14HvD1iPhBRNx7qiNrSUTcOCJeVLudNO0xqbeIWBkRTwd+B7wNuAtzBx4B9qIEKV8P/CEi3hURh7c3Ukl7kOXAfsDhwO2BpwIfA34fEU+PiOVjPNd9IuLWYzzeyCLiYOBeXR66W0RcZdLj6SciHgL8CvgE8FAGCzwCrKfMMng1cHZEfDUi7t7OKCVpz2PwUZJUdwzw6Yh47bQH0oIbU7JLZm4nTXMw6i0ijqRk4/4L/d84bgH+DGzt8fgK4K+AMyPi/411kJK028GU16tTI2L9GI/7ijEeaz4eQfcZc8sor7FTFxEHRcRngfdTPljtJSmzPi6lZLL2cjvgcxHxubENUpL2YE67lqSl7ZfAv3VZv4IS1LkJJaNs78bjT4mIyzLzBS2PT+oQETcDPs/soGMCXwI+U339bWZuqO13KOV6vitlqt2Va/uupv+bUUmq6/W3cybz8SjgzsChjcdvA3w8Iu6ambvGMI7bR8Q9MnPaAbCT+jz2KKYcJK0+sPoi3V/nzwA+DXwZ+DFwUWburPZbARwB3ILyv9D9KL/fuhu0MWZJ2tMYfJSkpe0PmfmWfhtExAGUNw5/03joeRHxscz8fmuj05z2pEYJEXFd4AuUkgB1XwWenpnf6bVvZp4HfBb4bEQ8A/hL4KWUN5Z7tMw8CTN9tYeoGnL1bcI0gEH+dq4CHgf8M7Cm9tCdgIcD757nGGa8PCI+P61mKBFxczoDcDPjmPkZXzsibjOtmtHVB09fBq7eeOhHwHMz81O99q2ak/2qur0/Ip5AyfJ8Nn5gJUlj5bRrSdrDZeZFmfm3lCljdUH3IvjS2EXEXpQaXc3A45spHb57Bh6bMnNHZr6Xkp30Gna/WZakscjMbZn5OkpdwaZnz+PQ59P5mnUT4EHzON58Paqx/GU6u4l322YiqsYyH2Z24PEjwC36BR67ycwtmfl2yt+O5wHbxjJQSZLBR0nSFZ4D/Kax7h5VUEhq20uB6zbWvTEznzDq9MXqjeTTgAcCm+c7QElqysyPUbKu664XEc2A2KB+BXygse6lY25mM5CIWAM8pLH6XdWt7i8iYt1kRtXhKZSp7nX/DfxFZm4Z9aCZuTUzXwbcitn/F0mSRmDwUZIEXDH96B2N1auBBdFtU0tX9Sb97xqrfwY8fRzHz8yPAv8xjmNJUhf/3WXdLedxvBcAO2rL12E62YUn0lkDcSPwUUq24Yba+vXAAyY2KiAi9qVkJ9adCzxuTPU2qcrO3GMcx5KkPZ01HyVJdd/osu7wUQ4UEYdQ3nxdg/LGZAvwg8z84gD7XplSAP5g4EBgE/An4BfA6dOqfTWXqvbUzSnjPojShflPlOYF350pcj9tVTbrbYDrAftS3kSeD/xfZv5+CkN6CrCqse6J88lcaRrlzegkr8OIuAbl2rkysBa4CPgpcFpmLripfxGxkpIVdAPKVPnLgQuAb2Xmr8Z0jqD8/K8DHFatPh84IzN/MI5zjFPV5fjWlN/hQZROun8Cfgd8c5zXc+O8y4GbAkdTrtUVlC7wH8nMP7Vxzi5jWEv53q9LCVZdDpwFfC0z/zzA/uuB46r996F0Iz4H+Epmbmpn1GP1oy7rDh71YJn5q4h4B/C3tdUvjIj3tnUd9dAMeH545vcRER9uPP4oxlfnchCPZnZjsmdk5sXjPMl8r78qI/Q4dr8uLKe8Lvye8je3laz8iDiI8nf+msBelL8pvwO+mpmXjfE8AdyI8jp9EOV/iospfw++nZm/G9e5JC1ymenNmzdv3pbIDTiZUitq5nbKkPsf1dg/Kf/Mz3WeF9UeuytwCrCry7F6jgdYCTwJ+GGX/eq3C4C3AFcb4Ps5fo5jzXU7YoBz7AU8jdJRs9v3PHP7M/C2QcbdOP4pjeMcP8f2J/X6mQOHUGoobuozztOAO03wmp15I1Yfw0+n+Bwa+3U4x/nuDXy7z3kuBV4N7Ffbp2ObAc5xcmOfk4Z83pxde2w9pUHVxX3G/FPggfP4mewFvAj4Q59znAU8AVjWY8ynjHr+EcZ7L0oNvG19xrsZ+CRwyyGPfUSv3zflTf4rgQt7nPP4MX1//a6HQ6rnweYeY9hKyTq+Uo9jX4MSsLq8x/6XA68F9p3HeOe8Fro8R4a6foBrdxn7cwfc90WN/b5erb9yl5/r0wY8ZnMsdxnh935VSgC96zUF3KHx2C7gmhN83jVfoy8AVk3q/AOM736U+phb+7wubAE+M8LrQvN6fVHtsZtQuns3f3czt+2UTN1rz/P7uxZltsz5fb6/mb8HTwBWTvt34s2bt+nenHYtSarr1iE0B9oxYkVEvJnSrfgOPY7Va99bAmcCrwduOMfmB1GyQX4REc8a9BxtiIgTKfW5Xk355L/f93wA8NeUcTenGLcuIu5EeRPwOEpmXS+3BP43Ip47kYGVjJArNdY1p/9PxCSvw4hYHRHvo7xJvHmfTfehBLd/FBE3GvY841Sd/0fAs+icitl0FPChiHhz1RBimHMcA/yE0uzqsD6bXgN4I/CViGhmP01ERBwcEV+mBA+OpwSue9kLOAE4LSLeX2UKzufct6A8n59JycqduIi4QzWGv6V8f92sAh4DfCcirtXY/4GUANIj6OwWXbeGkhn9zSqzfKHap8u6eW
XMZeYfgDc0Vj+7yhKdhJPoLNF1DnBqbfmrdNZDDOCR7Q8LIuJwZr9GvzcXQJZ4RBwVEd+lTE+/I7Oz+utWUz68OC0i/rPqoD6fc/8D8B3Kh1q9XntXUBoYfT8i7jbCOVZFxOsppVEezdwZvkdRXqt/EhHXH/Z8kpYOg4/SFETEdSLiARHxxIh4dkQ8NiJOiIjrD/tGTRqzbm/uLhxw37dRAlt1OykZUj2nvEbECZTsgGv02OQSOmtfzVgDvCIi/mMaz5sq4PRRSnZK0y7KuLtNp1oDvC4imt3FW1MFHv+H2VPULqFkXnTzTxFxUovDmnGHLutOmcB5O0zyOqzeYH6E7l1yoWR7bWysuyrwpYg4cpBzjFtE3ICS3Xd446HL6B1keRyldt2g57gx/X8Hl1KyiOpuT7m2ewWvWhER16ZkCd+xxyYbKb/Hbh5CCZo2g+6DnvuGwBeZ/dqziXkGvIYYwy2Y/ZqyizK1s9vz5AjgMzNB1yrw+AFg79o2/f5eHAV8bAH/j9Ttg4HfdFk3rFdSrvsZVwL+YQzHHcRJjeV3Z+YVH0ZW95vTrB9ZTcNtW7e/G1+dwHn7qoJ536SUQehmI52/z7pHAV8c9YOJ6gPDf6XMJpixk97PyXXAJyLiekOc40DgS5TZAd3Kt23rc75rA9+IiGaDIEl7iIX6B1wLUEQsi4ijI+KREfH6iPhmRGyOiKzdjp/2OBeqiFgZEX8fET+iZNZ8mPKJ9suBt1OmY/0E+HNE/JcdhjUl3ZrLnDPAfg9kd+2nDcCLKbXgVmXmAZTAwE1ovFGJiKMob0Cb/2x/ilLkfU1m7k/JHDiK0hG5+Yb+McCze4zrF8Djq1vzTdIva4/1unWtVRYRT6RMPa2/yfoD8Hzg2Or73j8z11GyAv4K+HHjME+PiElkiRxKmWK1mvJG5D8pb9xWV2PcCziS8rNtBiJfExH7tzy+YxvLWynZUBMzgeuw6RWUzJS631OCdYdm5trMXE/JaHsk5W8GlMDDewc8xzjtBXyMUtuR6v49gHWZuW9m7g1chRIUuaSx73Mi4jpznaCqi/ZRZmfxfZGSMbguM/fLzDXA1YEnA+dV29ySkik5EdXf548zO0j6Y0oW3wGZuT4z11KyNx9PqbVWdwvgfSMGav6L3Zl2pwL3p0xL3rv6XRxICRz9cYRjD2Iv4IOU58sOyrTrW1Je9w6kvN7fjhKsrrsO8Mwq2HEyJUiymRJgO4YyLXPm78U9gWZdz1tRMq0Wogc3lndSgtPzkpkXUQJKdU8bNXA9qIi4PeXvQl23eo7vonN2xOHAndoaV82Nu6z77gTO21NEHEv5X37f2urNlOfH8ZTXsPWZuR8l8HcPStZ03e0pWYLDuhvl7xKU4OaL2P2cOpDyt+vmlNeOujXAWwc5QVXn95OU53bdqcDDgKtk5ura+W5EeY9Tb0y0L/Dhqia4pD3NtOd9e1scN0qGxkb61/QYW32hpXajfAL64wF+fvVb1/pI3rz1uzGPulWUT7HPauy/BdhrgPPUa/tcdcDzLaPUSazvvxN49Bz7XZsSEG3WMLrZHPudNOrPpnGcmzG7htN7gfUD/Hzf0thvE3DYHPudMszrbJfvc+b2J+A2c+x7x+pnWd/vyS1fsz9onO+MCT9nJn0d3orZtbi+0O/6oQSO39fj95oDfI/N5+tJc2x/fI9zbQLuN8e+N6C8+a3v928DjPHfu5yvb307yhvZr/UY6yktXjOv73K+t9GnphmlVubnu+z31DnOdUSv3zvwDxN6jvS6Hi4GbttnvxWU4Ep9nwuBr1f3zwau12f/vYHvN/b//gjjnfNa6PIcGfj6oXz41vzZfHqI/V/U2PfrXX4Ozbp6fZ9TXcYzVM1H4J39xtTY9tTGtu+dwDX5iea1OInnQp/x7Mvs/5++Bxw5wL6PogTx6/ved8jrtX7Ouf6neH6X/Y4ZYJyvbuyzGXjoAPtdi/IBWn3fj07z9+XNm7fp3Mx81KBuSvmUTkOKiOMo08iOrq3+HSUI8Q+Ufzr+jvLG6zT6TE+VWvYyZmfyfDYze00dbLoUuGsO3i35RGZPVXtmZv5nv50y85fAXeic0rwCmFSNwlfRWcPpg8AjMnNDj+0ByMwdlAyoT9dWr6XUM2vbDkrQ6P/6bZSZX6E0pKl7YGujKpp1/ebsjDtmJzLZ6/D5dM48+TlwYr/rJzO3UjIgpz2t8DGZ+bF+G2TmjymZz3V9r6GqZuPfNFb/a2b+2xznupRSL+3sftuNU1V3sDnWTwF/m5nbe+1X/X5PZHZW77MiYvUIQ/nXzHz1CPuN00Mz8+u9Hqxe855A5/81B1I68G4FTsjMn/fZfyOzXx9vPK3SA01V7bunMDsbeRuDZ0HPqfo5vKyx+gkRcbVxnaMuIvZm9nP2XX12ObmxfP+I2Hesg5qtWXLg4pbPN5en0Pn/05mUgO+v59oxM9/J7OvlOSOM4ffA3TNzroznlzE7q3iu1+hr0vlcTOBBmfn+uQaVmb+ivE7XO2yfWM04kLQHMfioUWylFDN+C9OZ/rVoVMXVP8/u6VGXUQqzH5GZj8/Mf8vMkzPzDZn51Mw8jvIP1XMp/7xKrYuI/SPiLcAzGg8ls4MI/bw0M88dYvvmm8ofAK8ZZMcq8PNPjdX3jYheteLGoqpzVq/xdjHwxMzMQfavtnsanW/G/3oCdcze2S9I0PC2xvKxLY+v2TyhVz2stkzsOoyIIyhT7er+LjO71QZtnmsmeD2tD6i+mJkfGHDbd9JZ8+tqc0yzO4nOmo1/ZMBp1FVQ72kDjmscHk/nhw+XA08Y5DWg+iCnWRf3EOAvhxzDnxiilmZLPpmZn51ro8w8h5Lp2PTmzPzRAPt/FfhtY/XNBhvivFw5Ih7X5fbEiHhORLynGtdrKZnJM3ZSgvRzfm9DegudJVBW016pgQfRWYtzC6V0Ry8forPW6F7MnoY+bns3li9p+Xw9VTUam03kHp+ZwwREX0PJnJxxy6oG7jCemZlz1ujOzF2UDvR1/ZqeATydzlqS78nM5pTxfuf8NSXJYkZQ3g9J2oMYfNSg3k35pP+mlKlht8jMxwP/O91hLVxVHad3sPuN9Qbgbpn5tuoPf1eZeX5mvjwzL+u1jTSEfm+gnh8RH6W8ger2T+DLMvOMAc+zndnZDz1FxD7AbRurX5+ZOwc9BiVDr16jcBmzAzvj9rDG8vsyc6hMvSpgVa9NdQClNlObmtmMPVWZa/XXn3VAKxk2lWbW10QaZsBUrsP70Pm/15mZ+aVBT5SZP6Vk0k/DMNfQxZROqHX9mhrcvbF88iAB2ZpPAsN88DEf92osf2SIbG8y85vAt+Y45lzeM0RGelvePsS23+myrhn86KdZy2/gBhnzcG3KNd+8vYGSOfZwSuC47kxK9v/YP5jP0sX5RY3VJ0XEdcd9LnbXb57x8SrLuKsqM/Ojcxxj3Ob9dyMift+oW9/vdkqfQ92dUo93xo+rWQQDqz5c+nBj9fFDHOIi+geIm77RWO75nKo+fGx+QPK6Ic41o1lv8vgRjiFpETP4qIFk5gsy8+2Ze
Xq/aUXjEMVNI+IREfEPEfH06v7Rc++9oDycUjh6xjMzs/mGQ2pbvzdQLwHux+wMAoB/z8znD3GeHw4ZhLsVnX+DktlvXvrKzEsozSjq2u6ieIfG8udGPM73GsvHjXicQVzC7ClWc/lNY3m/sYyku2b34kmW+Jj0ddj8PX98mHNVhhrfGJ065PZnNZb367ZR9UHdLRqr/2eYE1XB4s8Ps88oqgynGzdWf2SEQzWDBMO+bg0V2GhB0j2bsZdm5uJFlPrAo+6/3xD7TkJSyikcPWzQaUjvoTOov5zdTUbGopqx02wo0m/Kda9tbtXytNrm7KBploZaCP8XfL0KYA5qoNfnyjHsbjYGcGFmNsc6p8z8GZ2N2m5YTfGXtIdYMe0BSDMiYj3wTOCxzP40eWabXwIvzMzmp2cL0ZNq93/FgN3kpCn7EfDsYabT1PYbxg0by78ecorSjO9SOuHOaC2DsAo8NMd98xHrbjWn5TbrHo7Tb/tlW/fQrD+4T9etxmMjnVNu264VVjfp67BZW3LoN3Aj7jNfl2XpujuMQa+hw+h845uUBkDD+v4I+wzrKGb/79wtq28uzUy+q0TEAUP8jMc9pXdYl1ZB90E1s9J+O2ipisrGxnKbr0ejCEqJkjXA89o6SWbujIjn0RnwfmBEHJuZp4/pNCc1lv/I7A9XuvkyJUh89caxnjmWUc3WvCYm+Xej6VaN5atERLO8wiCawdph/i84e8hzDfM3vvn9bRzx+4MSNN6rur8MOJjZv0tJS5TBRy0IEXErSgZIv5pQULK43h8R9wMe1nYW5qgi4hg6MzneMcKbf6lNOylTay8Bfgl8G/jcXA1J+hi2SciBjeVmpt2gmsXcm8cdp4OZPWNgXDW32hz3JSPs05x2vLzbRhFxR2CYaX+fzMw/NNb9kc4pawcMcbz5mvR12Fx/9gjnGmWf+bpkhH0GuobozKgB2FBN4xzWXE0WxqH5+9s+ZJ3bGd2aUBxIyQgcxKSbMjUNWxameS3Md/9e19I4nZqZx9dXVFm6ewNHAnej1IudaXyyDHhuRKzKzGb95LHJzI9GxHfYXaMvgJczhpIj1fTav2qsfu8gZSgyM6s6mPWGW4+IiOcMWcZiUH+glIKaMcrfjWfRfeYHlAZfzaBbL4c2lh9S3eZrmP8LLhnmwFUgu76q32zI5vd3BEOU4ZjDgczOwpS0RBl81NRVb14/Ten6OuPMat2vKUXrrwv8Bbvrjj2Ikh3RdkHrUd2tsTzUFDJpjGa9gWrJsMGCZsBh1CYjzf3aDFy1GSBcO/cmIxsmw2hYj6xug/o55U1j3a/pzEA8KiJWTujDpUlfh83zjVLbd9INeaDda2i/xnLfrvF9TKJOclvXCwzx2jVicHac5ns9tHk9tabK1txAycw9IyLeTCmDcJfaZv8YEd/NzGHq7w3rOXRmI949Iu6QmcOWRmi6C7Pr+w4y5XrGyXQGHw+jBEWHnUkxiFnThiPiKsN8GNCvNmeVFDFo8LGt/w2G+b+gzefUYv3fR9ICY81HTVVEHEwpQDzzx2cL8BjgqMx8ema+uao1+XRKALI+dfkvIuIRkx3xwOpZjxuAHwNExLER8YaI+ElEXBYRGyPiNxHxsYj4m4jYq/vhpCUnGsvj+se5zX/AV829yciaP489SXO64GpmT4duy7Svw0UZhBmzZs3PUZ9nbT4/Z7R1vYz7WJqAqtP6/Zldv/JNEXGlLruM67xfYnbjqVeM4dCP7rLux4M2ZaHMomhqq/HMGV3WzdWxuS1tvfYslP8L/N9H0lgYfNS0vZLdU613AffLzP/sVgsoMy/PzMfRWevmpdU0kYXmJrX7vwTWRMQbKXWenghcn9IFex1l+sKJlMDqWRFx/4mOVJqO5vTC/UY8TrPO0yj1+gbVnOqYwLrMjDHcTmpx3Atdt2ydO07o3JO+DpvrR6lTNs3aZm2Y9TOJxnzAAe03hrHMpXm9jPq76LZfm69dakkVgHwU5X/YGQcy5kYwXTy7sXxcRJzQdcsBRMR+wP+b14i6OyEi2sic+2qXdbfvsm4Smv8b3HNM/xccMY1vpovm9/fBMX1/kZmnTOMbkjQdCzFooz1ERBwKPKy26j8yc5AOcU8GZqbjHQ7ca9xjG4ODavcvAD4EPIHdn/BtA37P7BothwIfjointD1Aacqa/8weMeJxrjnHccfpT43l6HL+PUpmnjSGNxrfYPbv7THtjx66nPeIEY8z6HU4jvONss9Cdj6dWX+rGO151WZn3RnN39+qiLhy1y376/b9TbuOo0aUmd+mdKKue2zVObrNc368sfpl8/hA/qF0Nv4al1V0/q8/Fpl5FrMzTh8eEZPIgG5q/m/Q2u99Spb69ydpQgw+apoeSGcq/2sG2alqVvCl2qq7jnNQ81VlbKyvrbozuwOkZwL3A/bJzKtl5v7A9YB31g8B/FtE3HkS45WmpNmt9VpV5sWwbtZY/sFow5lb1QX5nMbq49s6354iM3cw+437UVU94LZN+jpsrr9p1636G2WfBauqX/jzxupBa63Nd59h/YxSh7qu+bsfRHOf34/QTVwLywspHyzPWAE8v+VzPo/OjMsbMnqjk+b06C8Ajx/x9ok5jj0u72wsH0SpCT9p328sHz+FMbSp+f3daMS/k5L2cDac0TTdrnb/rMxsvvno59vAPav7t+y1UURcdZSBDejSarpN0zo6A/srq6+nA3fKzI5C85l5JvDoiPgZ8Kpq9TLgtRFxTLcp6NIScBrlTdPMcyUogfnmm4meImJfZn/48I0+uzSDBqN0TP0i8Nja8oOBN4xwHHX6d0p2eP0DqTdGxLGZuWUcJ4iIZZm5q7F60tfhNykZRjNOpHRcHcZSLM3xf3RmLj4MeN+gO1fZh8ePeUyzZObmiDiDzuDhA4BPDnmov2gs93vd0iKQmedExLuAv66tflhEvDQzf9XSOX8SEe+ls0P1SyJiqGY3EXEDZgfE/zkzm3UlBz3ed+icwn3jiLhxZp4xyvH6eAelwc1+tXWviojPTjiY/0U6G6/dIyL2bf6/v4h9A9hEeX8DJX7wAMrPX5IGZuajpulGtfs/GXLf82v3+wUYf9fi7Yk9znl5l3W7gIf3+0ckM/+FzgLiN6Czg6K0ZGTmZcDXGqufNOSUsccB9SZNu4B+pRuaHxaMUq+t+abuthFx9xGOo5rMPBt4Y2P1UcC/jOP4EXE/OoPGM+ed9HX4aTozla4bEQO/zkfE9YE7DTG2xaIZaLx71W12UC9gtA8TRtHs3PvAYaZeR8QtmZ2l2UY3YE3ey9ldFgjKNdl29mMz4/KadHmtm0MzM/E84JRRB5SZ32N285mxZz9WsxFe1lh9ZeBtE64H/1mg3oF+HcN/qLRgZeY2ZmezPj8iVk9jPJIWL4OPmqZ6AeoTBu2mV3XUe1Nt3/0nPO6+MnMnpWt33Rcy82cD7P7axvKCmlIujdnrGsvHUmq6zikijmT2m7pPZOZv+uz2x8bytYatD5WZX2R2ltI7I+LqwxynbsTmGkvR85j9hvVJEfHGUd9IRsSaiHg1pVHZ2h6b
[... base64-encoded PNG data omitted: figure showing the distribution of the number of cells expressing each protein-coding or miRNA gene ...]\n",
|
217 |
+
"text/plain": [
|
218 |
+
"<Figure size 1500x750 with 1 Axes>"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"metadata": {
|
222 |
+
"needs_background": "light"
|
223 |
+
},
|
224 |
+
"output_type": "display_data"
|
225 |
+
}
|
226 |
+
],
|
227 |
+
"source": [
|
228 |
+
"gene_detection_counts = [i for i in gene_detection_counts_dict.values()]\n",
|
229 |
+
"import seaborn as sns\n",
|
230 |
+
"import matplotlib.pyplot as plt\n",
|
231 |
+
"plt.figure(figsize=(10,5), dpi=150)\n",
|
232 |
+
"plt.rcParams.update({'font.size': 18})\n",
|
233 |
+
"count_plot = sns.distplot(gene_detection_counts).set_title(f\"# Cells Expressing Each\\nProtein-Coding or miRNA Gene\")"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 47,
|
239 |
+
"id": "missing-bradley",
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [
|
242 |
+
{
|
243 |
+
"data": {
|
244 |
+
"text/plain": [
|
245 |
+
"27454"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
"execution_count": 47,
|
249 |
+
"metadata": {},
|
250 |
+
"output_type": "execute_result"
|
251 |
+
}
|
252 |
+
],
|
253 |
+
"source": [
|
254 |
+
"len(gene_detection_counts)"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": 55,
|
260 |
+
"id": "perfect-signal",
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [
|
263 |
+
{
|
264 |
+
"data": {
|
265 |
+
"text/plain": [
|
266 |
+
"25424"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
"execution_count": 55,
|
270 |
+
"metadata": {},
|
271 |
+
"output_type": "execute_result"
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"len([i for i in gene_detection_counts if i > 0])"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "code",
|
280 |
+
"execution_count": 56,
|
281 |
+
"id": "faced-theory",
|
282 |
+
"metadata": {},
|
283 |
+
"outputs": [
|
284 |
+
{
|
285 |
+
"data": {
|
286 |
+
"text/plain": [
|
287 |
+
"22735"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
"execution_count": 56,
|
291 |
+
"metadata": {},
|
292 |
+
"output_type": "execute_result"
|
293 |
+
}
|
294 |
+
],
|
295 |
+
"source": [
|
296 |
+
"len([i for i in gene_detection_counts if i > 100])"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "code",
|
301 |
+
"execution_count": 57,
|
302 |
+
"id": "tough-workplace",
|
303 |
+
"metadata": {},
|
304 |
+
"outputs": [
|
305 |
+
{
|
306 |
+
"data": {
|
307 |
+
"text/plain": [
|
308 |
+
"21167"
|
309 |
+
]
|
310 |
+
},
|
311 |
+
"execution_count": 57,
|
312 |
+
"metadata": {},
|
313 |
+
"output_type": "execute_result"
|
314 |
+
}
|
315 |
+
],
|
316 |
+
"source": [
|
317 |
+
"len([i for i in gene_detection_counts if i > 1000])"
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"execution_count": 49,
|
323 |
+
"id": "cooperative-camcorder",
|
324 |
+
"metadata": {},
|
325 |
+
"outputs": [
|
326 |
+
{
|
327 |
+
"data": {
|
328 |
+
"text/plain": [
|
329 |
+
"173152.0299000284"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"execution_count": 49,
|
333 |
+
"metadata": {},
|
334 |
+
"output_type": "execute_result"
|
335 |
+
}
|
336 |
+
],
|
337 |
+
"source": [
|
338 |
+
"gene_detection_event_digest = crick.tdigest.TDigest()\n",
|
339 |
+
"gene_detection_event_digest.update(gene_detection_counts)\n",
|
340 |
+
"gene_detection_event_digest.quantile(0.5)"
|
341 |
+
]
|
342 |
+
}
|
343 |
+
],
|
344 |
+
"metadata": {
|
345 |
+
"kernelspec": {
|
346 |
+
"display_name": "Python 3 (ipykernel)",
|
347 |
+
"language": "python",
|
348 |
+
"name": "python3"
|
349 |
+
},
|
350 |
+
"language_info": {
|
351 |
+
"codemirror_mode": {
|
352 |
+
"name": "ipython",
|
353 |
+
"version": 3
|
354 |
+
},
|
355 |
+
"file_extension": ".py",
|
356 |
+
"mimetype": "text/x-python",
|
357 |
+
"name": "python",
|
358 |
+
"nbconvert_exporter": "python",
|
359 |
+
"pygments_lexer": "ipython3",
|
360 |
+
"version": "3.10.11"
|
361 |
+
}
|
362 |
+
},
|
363 |
+
"nbformat": 4,
|
364 |
+
"nbformat_minor": 5
|
365 |
+
}
|
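The notebook above summarizes how many cells express each protein-coding or miRNA gene and uses a streaming t-digest from the crick library to report the median detection count without holding a sorted copy of the data. Below is a self-contained sketch of that same pattern on a small, hypothetical counts dictionary; the gene IDs and values are illustrative placeholders, not corpus data.

# Sketch (hypothetical data): threshold counts and a streaming median, as in the cells above.
from crick.tdigest import TDigest

gene_detection_counts_dict = {"ENSG_A": 0, "ENSG_B": 150, "ENSG_C": 2400}  # placeholder values
gene_detection_counts = list(gene_detection_counts_dict.values())

# How many genes are detected in more than 0 / 100 / 1000 cells.
for threshold in (0, 100, 1000):
    n_genes = sum(1 for c in gene_detection_counts if c > threshold)
    print(f"genes detected in >{threshold} cells: {n_genes}")

# update() streams the values into the digest; quantile(0.5) returns the median.
detection_digest = TDigest()
detection_digest.update(gene_detection_counts)
print(detection_digest.quantile(0.5))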
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# run with:
|
5 |
+
# deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
|
6 |
+
|
7 |
+
import datetime
|
8 |
+
|
9 |
+
# imports
|
10 |
+
import os
|
11 |
+
|
12 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
13 |
+
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
|
14 |
+
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
15 |
+
|
16 |
+
import pickle
|
17 |
+
import random
|
18 |
+
import subprocess
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import pytz
|
22 |
+
import torch
|
23 |
+
from datasets import load_from_disk
|
24 |
+
from transformers import BertConfig, BertForMaskedLM, TrainingArguments
|
25 |
+
|
26 |
+
from geneformer import GeneformerPretrainer
|
27 |
+
|
28 |
+
seed_num = 0
|
29 |
+
random.seed(seed_num)
|
30 |
+
np.random.seed(seed_num)
|
31 |
+
seed_val = 42
|
32 |
+
torch.manual_seed(seed_val)
|
33 |
+
torch.cuda.manual_seed_all(seed_val)
|
34 |
+
|
35 |
+
# set local time/directories
|
36 |
+
timezone = pytz.timezone("US/Eastern")
|
37 |
+
rootdir = "/parent_ouput_directory"
|
38 |
+
|
39 |
+
# set model parameters
|
40 |
+
# model type
|
41 |
+
model_type = "bert"
|
42 |
+
# max input size
|
43 |
+
max_input_size = 2**11 # 2048
|
44 |
+
# number of layers
|
45 |
+
num_layers = 6
|
46 |
+
# number of attention heads
|
47 |
+
num_attn_heads = 4
|
48 |
+
# number of embedding dimensions
|
49 |
+
num_embed_dim = 256
|
50 |
+
# intermediate size
|
51 |
+
intermed_size = num_embed_dim * 2
|
52 |
+
# activation function
|
53 |
+
activ_fn = "relu"
|
54 |
+
# initializer range, layer norm, dropout
|
55 |
+
initializer_range = 0.02
|
56 |
+
layer_norm_eps = 1e-12
|
57 |
+
attention_probs_dropout_prob = 0.02
|
58 |
+
hidden_dropout_prob = 0.02
|
59 |
+
|
60 |
+
|
61 |
+
# set training parameters
|
62 |
+
# total number of examples in Genecorpus-30M after QC filtering:
|
63 |
+
num_examples = 27_406_208
|
64 |
+
# number gpus
|
65 |
+
num_gpus = 12
|
66 |
+
# batch size for training and eval
|
67 |
+
geneformer_batch_size = 12
|
68 |
+
# max learning rate
|
69 |
+
max_lr = 1e-3
|
70 |
+
# learning schedule
|
71 |
+
lr_schedule_fn = "linear"
|
72 |
+
# warmup steps
|
73 |
+
warmup_steps = 10_000
|
74 |
+
# number of epochs
|
75 |
+
epochs = 3
|
76 |
+
# optimizer
|
77 |
+
optimizer = "adamw"
|
78 |
+
# weight_decay
|
79 |
+
weight_decay = 0.001
|
80 |
+
|
81 |
+
|
82 |
+
# output directories
|
83 |
+
current_date = datetime.datetime.now(tz=timezone)
|
84 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
|
85 |
+
run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
|
86 |
+
training_output_dir = f"{rootdir}/models/{run_name}/"
|
87 |
+
logging_dir = f"{rootdir}/runs/{run_name}/"
|
88 |
+
model_output_dir = os.path.join(training_output_dir, "models/")
|
89 |
+
|
90 |
+
|
91 |
+
# ensure not overwriting previously saved model
|
92 |
+
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
|
93 |
+
if os.path.isfile(model_output_file) is True:
|
94 |
+
raise Exception("Model already saved to this directory.")
|
95 |
+
|
96 |
+
|
97 |
+
# make training and model output directories
|
98 |
+
subprocess.call(f"mkdir {training_output_dir}", shell=True)
|
99 |
+
subprocess.call(f"mkdir {model_output_dir}", shell=True)
|
100 |
+
|
101 |
+
|
102 |
+
# load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/token_dictionary.pkl)
|
103 |
+
with open("token_dictionary.pkl", "rb") as fp:
|
104 |
+
token_dictionary = pickle.load(fp)
|
105 |
+
|
106 |
+
# model configuration
|
107 |
+
config = {
|
108 |
+
"hidden_size": num_embed_dim,
|
109 |
+
"num_hidden_layers": num_layers,
|
110 |
+
"initializer_range": initializer_range,
|
111 |
+
"layer_norm_eps": layer_norm_eps,
|
112 |
+
"attention_probs_dropout_prob": attention_probs_dropout_prob,
|
113 |
+
"hidden_dropout_prob": hidden_dropout_prob,
|
114 |
+
"intermediate_size": intermed_size,
|
115 |
+
"hidden_act": activ_fn,
|
116 |
+
"max_position_embeddings": max_input_size,
|
117 |
+
"model_type": model_type,
|
118 |
+
"num_attention_heads": num_attn_heads,
|
119 |
+
"pad_token_id": token_dictionary.get("<pad>"),
|
120 |
+
"vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
|
121 |
+
}
|
122 |
+
|
123 |
+
config = BertConfig(**config)
|
124 |
+
model = BertForMaskedLM(config)
|
125 |
+
model = model.train()
|
126 |
+
|
127 |
+
# define the training arguments
|
128 |
+
training_args = {
|
129 |
+
"learning_rate": max_lr,
|
130 |
+
"do_train": True,
|
131 |
+
"do_eval": False,
|
132 |
+
"group_by_length": True,
|
133 |
+
"length_column_name": "length",
|
134 |
+
"disable_tqdm": False,
|
135 |
+
"lr_scheduler_type": lr_schedule_fn,
|
136 |
+
"warmup_steps": warmup_steps,
|
137 |
+
"weight_decay": weight_decay,
|
138 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
+
"num_train_epochs": epochs,
|
140 |
+
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
|
142 |
+
"logging_steps": 1000,
|
143 |
+
"output_dir": training_output_dir,
|
144 |
+
"logging_dir": logging_dir,
|
145 |
+
}
|
146 |
+
training_args = TrainingArguments(**training_args)
|
147 |
+
|
148 |
+
print("Starting training.")
|
149 |
+
|
150 |
+
# define the trainer
|
151 |
+
trainer = GeneformerPretrainer(
|
152 |
+
model=model,
|
153 |
+
args=training_args,
|
154 |
+
# pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
|
155 |
+
train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
|
156 |
+
# file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
|
157 |
+
example_lengths_file="genecorpus_30M_2048_lengths.pkl",
|
158 |
+
token_dictionary=token_dictionary,
|
159 |
+
)
|
160 |
+
|
161 |
+
# train
|
162 |
+
trainer.train()
|
163 |
+
|
164 |
+
# save model
|
165 |
+
trainer.save_model(model_output_dir)
|
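The launch command in the header of this script points at a DeepSpeed configuration file (ds_config.json) that is not part of this commit, so its contents here are an assumption. The sketch below writes a minimal, illustrative ZeRO stage-2 configuration that defers batch-size and accumulation settings to the HuggingFace Trainer via "auto"; adapt or replace every field to match the actual pretraining setup.

# Illustrative ds_config.json only; the real configuration used for pretraining is not included here.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": "auto",   # filled in from TrainingArguments by the Trainer
    "gradient_accumulation_steps": "auto",
    "fp16": {"enabled": True},                  # mixed-precision training
    "zero_optimization": {
        "stage": 2,                             # shard optimizer states and gradients across GPUs
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

with open("ds_config.json", "w") as fp:
    json.dump(ds_config, fp, indent=2)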
examples/tokenizing_scRNAseq_data.ipynb
ADDED
@@ -0,0 +1,72 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
|
6 |
+
"metadata": {
|
7 |
+
"tags": []
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"## Tokenizing .loom single cell RNA-seq data to rank value encoding .dataset format"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"id": "350e6252-b783-494b-9767-f087eb868a15",
|
16 |
+
"metadata": {},
|
17 |
+
"source": [
|
18 |
+
"#### Input data is a directory with .loom files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. \n",
|
19 |
+
"\n",
|
20 |
+
"#### Genes should be labeled with Ensembl IDs (row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (column attribute \"n_counts\") to be used for normalization.\n",
|
21 |
+
"\n",
|
22 |
+
"#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
|
23 |
+
"\n",
|
24 |
+
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
25 |
+
"\n",
|
26 |
+
"#### If one's data is in other formats besides .loom, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom format prior to running the transcriptome tokenizer."
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"id": "080fdd9c-0c48-4d5d-a254-52b6c53cdf78",
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"from geneformer import TranscriptomeTokenizer"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": null,
|
42 |
+
"id": "37205758-aa52-4443-a383-0638519ee8a9",
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
|
47 |
+
"tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
|
48 |
+
]
|
49 |
+
}
|
50 |
+
],
|
51 |
+
"metadata": {
|
52 |
+
"kernelspec": {
|
53 |
+
"display_name": "Python 3 (ipykernel)",
|
54 |
+
"language": "python",
|
55 |
+
"name": "python3"
|
56 |
+
},
|
57 |
+
"language_info": {
|
58 |
+
"codemirror_mode": {
|
59 |
+
"name": "ipython",
|
60 |
+
"version": 3
|
61 |
+
},
|
62 |
+
"file_extension": ".py",
|
63 |
+
"mimetype": "text/x-python",
|
64 |
+
"name": "python",
|
65 |
+
"nbconvert_exporter": "python",
|
66 |
+
"pygments_lexer": "ipython3",
|
67 |
+
"version": "3.10.11"
|
68 |
+
}
|
69 |
+
},
|
70 |
+
"nbformat": 4,
|
71 |
+
"nbformat_minor": 5
|
72 |
+
}
|
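The markdown cells above spell out the attributes the tokenizer expects in each input .loom file. As one possible way to produce such a file, the sketch below uses loompy to write a genes-by-cells raw count matrix with the required "ensembl_id" row attribute and "n_counts" column attribute plus the optional metadata columns mentioned above; the matrix, identifiers, and labels are hypothetical placeholders rather than real data.

# Hypothetical example of writing a .loom file in the layout the tokenizer expects.
import numpy as np
import loompy

counts = np.random.poisson(1.0, size=(20000, 200))            # genes x cells, raw counts
row_attrs = {
    "ensembl_id": np.array([f"ENSG{i:011d}" for i in range(counts.shape[0])]),  # placeholder IDs
}
col_attrs = {
    "n_counts": counts.sum(axis=0),                            # total read count per cell
    "cell_type": np.array(["unknown"] * counts.shape[1]),      # optional custom attribute
    "organ_major": np.array(["heart"] * counts.shape[1]),      # optional custom attribute
    "filter_pass": np.ones(counts.shape[1], dtype=int),        # 1 = include cell when tokenizing
}
loompy.create("loom_data_directory/example.loom", counts, row_attrs, col_attrs)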
fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/config.json
ADDED
@@ -0,0 +1,35 @@
1 |
+
{
|
2 |
+
"_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.02,
|
7 |
+
"gradient_checkpointing": false,
|
8 |
+
"hidden_act": "relu",
|
9 |
+
"hidden_dropout_prob": 0.02,
|
10 |
+
"hidden_size": 256,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0",
|
13 |
+
"1": "LABEL_1",
|
14 |
+
"2": "LABEL_2"
|
15 |
+
},
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"intermediate_size": 512,
|
18 |
+
"label2id": {
|
19 |
+
"LABEL_0": 0,
|
20 |
+
"LABEL_1": 1,
|
21 |
+
"LABEL_2": 2
|
22 |
+
},
|
23 |
+
"layer_norm_eps": 1e-12,
|
24 |
+
"max_position_embeddings": 2048,
|
25 |
+
"model_type": "bert",
|
26 |
+
"num_attention_heads": 4,
|
27 |
+
"num_hidden_layers": 6,
|
28 |
+
"pad_token_id": 0,
|
29 |
+
"position_embedding_type": "absolute",
|
30 |
+
"problem_type": "single_label_classification",
|
31 |
+
"transformers_version": "4.6.0",
|
32 |
+
"type_vocab_size": 2,
|
33 |
+
"use_cache": true,
|
34 |
+
"vocab_size": 25426
|
35 |
+
}
|
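For reference, a config like the one above is loaded back into its model class with the standard transformers API. The sketch below is hypothetical and assumes the checkpoint weights (pytorch_model.bin) have been downloaded into the same local directory as this config.json; it is not part of the commit itself.

    from transformers import BertForSequenceClassification

    # Load the fine-tuned 3-class cell classifier from a local checkpoint directory.
    model = BertForSequenceClassification.from_pretrained(
        "fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224"
    )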
fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/trainer_state.json
ADDED
@@ -0,0 +1,150 @@
{
  "best_metric": 0.39658036828041077,
  "best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
  "epoch": 0.9,
  "global_step": 7020,
  "is_hyper_param_search": true,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.1,
      "learning_rate": 0.00034606438343856935,
      "loss": 0.911,
      "step": 780
    },
    {
      "epoch": 0.1,
      "eval_accuracy": 0.4531576503366612,
      "eval_loss": 1.4550466537475586,
      "eval_runtime": 66.5164,
      "eval_samples_per_second": 259.004,
      "step": 780
    },
    {
      "epoch": 0.2,
      "learning_rate": 0.0006921287668771387,
      "loss": 0.6273,
      "step": 1560
    },
    {
      "epoch": 0.2,
      "eval_accuracy": 0.5953680055723242,
      "eval_loss": 0.846651554107666,
      "eval_runtime": 66.1267,
      "eval_samples_per_second": 260.53,
      "step": 1560
    },
    {
      "epoch": 0.3,
      "learning_rate": 0.0007330550166223805,
      "loss": 0.5592,
      "step": 2340
    },
    {
      "epoch": 0.3,
      "eval_accuracy": 0.5935105641978176,
      "eval_loss": 1.0599186420440674,
      "eval_runtime": 66.2608,
      "eval_samples_per_second": 260.003,
      "step": 2340
    },
    {
      "epoch": 0.4,
      "learning_rate": 0.0006283471571048975,
      "loss": 0.3714,
      "step": 3120
    },
    {
      "epoch": 0.4,
      "eval_accuracy": 0.686324587880195,
      "eval_loss": 1.184874415397644,
      "eval_runtime": 66.1411,
      "eval_samples_per_second": 260.473,
      "step": 3120
    },
    {
      "epoch": 0.5,
      "learning_rate": 0.0005236392975874146,
      "loss": 0.2976,
      "step": 3900
    },
    {
      "epoch": 0.5,
      "eval_accuracy": 0.7681100534014396,
      "eval_loss": 0.6318939328193665,
      "eval_runtime": 66.3309,
      "eval_samples_per_second": 259.728,
      "step": 3900
    },
    {
      "epoch": 0.6,
      "learning_rate": 0.0004189314380699318,
      "loss": 0.2564,
      "step": 4680
    },
    {
      "epoch": 0.6,
      "eval_accuracy": 0.7807058277223126,
      "eval_loss": 0.7283642888069153,
      "eval_runtime": 66.3416,
      "eval_samples_per_second": 259.686,
      "step": 4680
    },
    {
      "epoch": 0.7,
      "learning_rate": 0.0003142235785524487,
      "loss": 0.2336,
      "step": 5460
    },
    {
      "epoch": 0.7,
      "eval_accuracy": 0.8563965637334572,
      "eval_loss": 0.5184123516082764,
      "eval_runtime": 66.3416,
      "eval_samples_per_second": 259.686,
      "step": 5460
    },
    {
      "epoch": 0.8,
      "learning_rate": 0.0002095157190349659,
      "loss": 0.1731,
      "step": 6240
    },
    {
      "epoch": 0.8,
      "eval_accuracy": 0.8288832133735778,
      "eval_loss": 0.5823884010314941,
      "eval_runtime": 66.1535,
      "eval_samples_per_second": 260.425,
      "step": 6240
    },
    {
      "epoch": 0.9,
      "learning_rate": 0.00010480785951748295,
      "loss": 0.1451,
      "step": 7020
    },
    {
      "epoch": 0.9,
      "eval_accuracy": 0.886812166241003,
      "eval_loss": 0.39658036828041077,
      "eval_runtime": 66.3555,
      "eval_samples_per_second": 259.632,
      "step": 7020
    }
  ],
  "max_steps": 7800,
  "num_train_epochs": 1,
  "total_flos": 0,
  "trial_name": null,
  "trial_params": {
    "learning_rate": 0.0008039341830649843,
    "lr_scheduler_type": "polynomial",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 12,
    "seed": 73.15243080311434,
    "warmup_steps": 1812.6785581609881,
    "weight_decay": 0.2588277764570262
  }
}
geneformer-12L-30M/config.json
ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.02,
  "gradient_checkpointing": false,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.02,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 2048,
  "model_type": "bert",
  "num_attention_heads": 8,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 25426
}
geneformer/__init__.py
ADDED
@@ -0,0 +1,12 @@
from . import tokenizer
from . import pretrainer
from . import collator_for_classification
from . import in_silico_perturber
from . import in_silico_perturber_stats
from .tokenizer import TranscriptomeTokenizer
from .pretrainer import GeneformerPretrainer
from .collator_for_classification import DataCollatorForGeneClassification
from .collator_for_classification import DataCollatorForCellClassification
from .emb_extractor import EmbExtractor
from .in_silico_perturber import InSilicoPerturber
from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/collator_for_classification.py
ADDED
@@ -0,0 +1,602 @@
1 |
+
"""
|
2 |
+
Geneformer collator for gene and cell classification.
|
3 |
+
|
4 |
+
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import warnings
|
9 |
+
from enum import Enum
|
10 |
+
from typing import Dict, List, Optional, Union
|
11 |
+
|
12 |
+
from transformers import (
|
13 |
+
DataCollatorForTokenClassification,
|
14 |
+
SpecialTokensMixin,
|
15 |
+
BatchEncoding,
|
16 |
+
)
|
17 |
+
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
+
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
+
|
20 |
+
from .pretrainer import token_dictionary
|
21 |
+
|
22 |
+
EncodedInput = List[int]
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
VERY_LARGE_INTEGER = int(
|
25 |
+
1e30
|
26 |
+
) # This is used to set the max input length for a model with infinite size input
|
27 |
+
LARGE_INTEGER = int(
|
28 |
+
1e20
|
29 |
+
) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
30 |
+
|
31 |
+
# precollator functions
|
32 |
+
|
33 |
+
class ExplicitEnum(Enum):
|
34 |
+
"""
|
35 |
+
Enum with more explicit error message for missing values.
|
36 |
+
"""
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def _missing_(cls, value):
|
40 |
+
raise ValueError(
|
41 |
+
"%r is not a valid %s, please select one of %s"
|
42 |
+
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
43 |
+
)
|
44 |
+
|
45 |
+
class TruncationStrategy(ExplicitEnum):
|
46 |
+
"""
|
47 |
+
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
48 |
+
tab-completion in an IDE.
|
49 |
+
"""
|
50 |
+
|
51 |
+
ONLY_FIRST = "only_first"
|
52 |
+
ONLY_SECOND = "only_second"
|
53 |
+
LONGEST_FIRST = "longest_first"
|
54 |
+
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
class PaddingStrategy(ExplicitEnum):
|
59 |
+
"""
|
60 |
+
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
61 |
+
in an IDE.
|
62 |
+
"""
|
63 |
+
|
64 |
+
LONGEST = "longest"
|
65 |
+
MAX_LENGTH = "max_length"
|
66 |
+
DO_NOT_PAD = "do_not_pad"
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
class TensorType(ExplicitEnum):
|
71 |
+
"""
|
72 |
+
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
73 |
+
tab-completion in an IDE.
|
74 |
+
"""
|
75 |
+
|
76 |
+
PYTORCH = "pt"
|
77 |
+
TENSORFLOW = "tf"
|
78 |
+
NUMPY = "np"
|
79 |
+
JAX = "jax"
|
80 |
+
|
81 |
+
|
82 |
+
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
83 |
+
mask_token = "<mask>"
|
84 |
+
mask_token_id = token_dictionary.get("<mask>")
|
85 |
+
pad_token = "<pad>"
|
86 |
+
pad_token_id = token_dictionary.get("<pad>")
|
87 |
+
padding_side = "right"
|
88 |
+
all_special_ids = [
|
89 |
+
token_dictionary.get("<mask>"),
|
90 |
+
token_dictionary.get("<pad>")
|
91 |
+
]
|
92 |
+
model_input_names = ["input_ids"]
|
93 |
+
|
94 |
+
def _get_padding_truncation_strategies(
|
95 |
+
self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
|
96 |
+
):
|
97 |
+
"""
|
98 |
+
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
99 |
+
and pad_to_max_length) and behaviors.
|
100 |
+
"""
|
101 |
+
old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
|
102 |
+
old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
|
103 |
+
|
104 |
+
# Backward compatibility for previous behavior, maybe we should deprecate it:
|
105 |
+
# If you only set max_length, it activates truncation for max_length
|
106 |
+
if max_length is not None and padding is False and truncation is False:
|
107 |
+
if verbose:
|
108 |
+
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
|
109 |
+
logger.warning(
|
110 |
+
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
111 |
+
"please use `truncation=True` to explicitly truncate examples to max length. "
|
112 |
+
"Defaulting to 'longest_first' truncation strategy. "
|
113 |
+
"If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
|
114 |
+
"more precisely by providing a specific strategy to `truncation`."
|
115 |
+
)
|
116 |
+
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
|
117 |
+
truncation = "longest_first"
|
118 |
+
|
119 |
+
# Get padding strategy
|
120 |
+
if padding is False and old_pad_to_max_length:
|
121 |
+
if verbose:
|
122 |
+
warnings.warn(
|
123 |
+
"The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
|
124 |
+
"use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
|
125 |
+
"use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
|
126 |
+
"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
|
127 |
+
"maximal input size of the model (e.g. 512 for Bert).",
|
128 |
+
FutureWarning,
|
129 |
+
)
|
130 |
+
if max_length is None:
|
131 |
+
padding_strategy = PaddingStrategy.LONGEST
|
132 |
+
else:
|
133 |
+
padding_strategy = PaddingStrategy.MAX_LENGTH
|
134 |
+
elif padding is not False:
|
135 |
+
if padding is True:
|
136 |
+
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
|
137 |
+
elif not isinstance(padding, PaddingStrategy):
|
138 |
+
padding_strategy = PaddingStrategy(padding)
|
139 |
+
elif isinstance(padding, PaddingStrategy):
|
140 |
+
padding_strategy = padding
|
141 |
+
else:
|
142 |
+
padding_strategy = PaddingStrategy.DO_NOT_PAD
|
143 |
+
|
144 |
+
# Get truncation strategy
|
145 |
+
if truncation is False and old_truncation_strategy != "do_not_truncate":
|
146 |
+
if verbose:
|
147 |
+
warnings.warn(
|
148 |
+
"The `truncation_strategy` argument is deprecated and will be removed in a future version, "
|
149 |
+
"use `truncation=True` to truncate examples to a max length. You can give a specific "
|
150 |
+
"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
|
151 |
+
"maximal input size of the model (e.g. 512 for Bert). "
|
152 |
+
" If you have pairs of inputs, you can give a specific truncation strategy selected among "
|
153 |
+
"`truncation='only_first'` (will only truncate the first sentence in the pairs) "
|
154 |
+
"`truncation='only_second'` (will only truncate the second sentence in the pairs) "
|
155 |
+
"or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
|
156 |
+
FutureWarning,
|
157 |
+
)
|
158 |
+
truncation_strategy = TruncationStrategy(old_truncation_strategy)
|
159 |
+
elif truncation is not False:
|
160 |
+
if truncation is True:
|
161 |
+
truncation_strategy = (
|
162 |
+
TruncationStrategy.LONGEST_FIRST
|
163 |
+
) # Default to truncate the longest sequences in pairs of inputs
|
164 |
+
elif not isinstance(truncation, TruncationStrategy):
|
165 |
+
truncation_strategy = TruncationStrategy(truncation)
|
166 |
+
elif isinstance(truncation, TruncationStrategy):
|
167 |
+
truncation_strategy = truncation
|
168 |
+
else:
|
169 |
+
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
170 |
+
|
171 |
+
# Set max length if needed
|
172 |
+
if max_length is None:
|
173 |
+
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
174 |
+
if self.model_max_length > LARGE_INTEGER:
|
175 |
+
if verbose:
|
176 |
+
if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
|
177 |
+
logger.warning(
|
178 |
+
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
179 |
+
"Default to no padding."
|
180 |
+
)
|
181 |
+
self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
|
182 |
+
padding_strategy = PaddingStrategy.DO_NOT_PAD
|
183 |
+
else:
|
184 |
+
max_length = self.model_max_length
|
185 |
+
|
186 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
187 |
+
if self.model_max_length > LARGE_INTEGER:
|
188 |
+
if verbose:
|
189 |
+
if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
|
190 |
+
logger.warning(
|
191 |
+
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
192 |
+
"Default to no truncation."
|
193 |
+
)
|
194 |
+
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
|
195 |
+
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
196 |
+
else:
|
197 |
+
max_length = self.model_max_length
|
198 |
+
|
199 |
+
# Test if we have a padding token
|
200 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
|
201 |
+
raise ValueError(
|
202 |
+
"Asking to pad but the tokenizer does not have a padding token. "
|
203 |
+
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
204 |
+
"or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
|
205 |
+
)
|
206 |
+
|
207 |
+
# Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
|
208 |
+
if (
|
209 |
+
truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
|
210 |
+
and padding_strategy != PaddingStrategy.DO_NOT_PAD
|
211 |
+
and pad_to_multiple_of is not None
|
212 |
+
and max_length is not None
|
213 |
+
and (max_length % pad_to_multiple_of != 0)
|
214 |
+
):
|
215 |
+
raise ValueError(
|
216 |
+
f"Truncation and padding are both activated but "
|
217 |
+
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
|
218 |
+
)
|
219 |
+
|
220 |
+
return padding_strategy, truncation_strategy, max_length, kwargs
|
221 |
+
|
222 |
+
def pad(
|
223 |
+
self,
|
224 |
+
encoded_inputs: Union[
|
225 |
+
BatchEncoding,
|
226 |
+
List[BatchEncoding],
|
227 |
+
Dict[str, EncodedInput],
|
228 |
+
Dict[str, List[EncodedInput]],
|
229 |
+
List[Dict[str, EncodedInput]],
|
230 |
+
],
|
231 |
+
class_type, # options: "gene" or "cell"
|
232 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
233 |
+
max_length: Optional[int] = None,
|
234 |
+
pad_to_multiple_of: Optional[int] = None,
|
235 |
+
return_attention_mask: Optional[bool] = True,
|
236 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
237 |
+
verbose: bool = True,
|
238 |
+
) -> BatchEncoding:
|
239 |
+
"""
|
240 |
+
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
241 |
+
in the batch.
|
242 |
+
|
243 |
+
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
244 |
+
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
245 |
+
|
246 |
+
.. note::
|
247 |
+
|
248 |
+
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
249 |
+
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
250 |
+
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
254 |
+
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
255 |
+
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
256 |
+
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
257 |
+
well as in a PyTorch Dataloader collate function.
|
258 |
+
|
259 |
+
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
260 |
+
see the note above for the return type.
|
261 |
+
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
262 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
263 |
+
index) among:
|
264 |
+
|
265 |
+
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
266 |
+
single sequence if provided).
|
267 |
+
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
268 |
+
maximum acceptable input length for the model if that argument is not provided.
|
269 |
+
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
270 |
+
different lengths).
|
271 |
+
max_length (:obj:`int`, `optional`):
|
272 |
+
Maximum length of the returned list and optionally padding length (see above).
|
273 |
+
pad_to_multiple_of (:obj:`int`, `optional`):
|
274 |
+
If set will pad the sequence to a multiple of the provided value.
|
275 |
+
|
276 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
277 |
+
>= 7.5 (Volta).
|
278 |
+
return_attention_mask (:obj:`bool`, `optional`):
|
279 |
+
Whether to return the attention mask. If left to the default, will return the attention mask according
|
280 |
+
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
281 |
+
|
282 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
283 |
+
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
284 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
285 |
+
|
286 |
+
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
287 |
+
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
288 |
+
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
289 |
+
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
290 |
+
Whether or not to print more information and warnings.
|
291 |
+
"""
|
292 |
+
# If we have a list of dicts, let's convert it in a dict of lists
|
293 |
+
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
294 |
+
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
295 |
+
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
296 |
+
|
297 |
+
# The model's main input name, usually `input_ids`, has be passed for padding
|
298 |
+
if self.model_input_names[0] not in encoded_inputs:
|
299 |
+
raise ValueError(
|
300 |
+
"You should supply an encoding or a list of encodings to this method"
|
301 |
+
f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
|
302 |
+
)
|
303 |
+
|
304 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
305 |
+
|
306 |
+
if not required_input:
|
307 |
+
if return_attention_mask:
|
308 |
+
encoded_inputs["attention_mask"] = []
|
309 |
+
return encoded_inputs
|
310 |
+
|
311 |
+
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
|
312 |
+
# and rebuild them afterwards if no return_tensors is specified
|
313 |
+
# Note that we lose the specific device the tensor may be on for PyTorch
|
314 |
+
|
315 |
+
first_element = required_input[0]
|
316 |
+
if isinstance(first_element, (list, tuple)):
|
317 |
+
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
|
318 |
+
index = 0
|
319 |
+
while len(required_input[index]) == 0:
|
320 |
+
index += 1
|
321 |
+
if index < len(required_input):
|
322 |
+
first_element = required_input[index][0]
|
323 |
+
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
324 |
+
if not isinstance(first_element, (int, list, tuple)):
|
325 |
+
if is_tf_available() and _is_tensorflow(first_element):
|
326 |
+
return_tensors = "tf" if return_tensors is None else return_tensors
|
327 |
+
elif is_torch_available() and _is_torch(first_element):
|
328 |
+
return_tensors = "pt" if return_tensors is None else return_tensors
|
329 |
+
elif isinstance(first_element, np.ndarray):
|
330 |
+
return_tensors = "np" if return_tensors is None else return_tensors
|
331 |
+
else:
|
332 |
+
raise ValueError(
|
333 |
+
f"type of {first_element} unknown: {type(first_element)}. "
|
334 |
+
f"Should be one of a python, numpy, pytorch or tensorflow object."
|
335 |
+
)
|
336 |
+
|
337 |
+
for key, value in encoded_inputs.items():
|
338 |
+
encoded_inputs[key] = to_py_obj(value)
|
339 |
+
|
340 |
+
# Convert padding_strategy in PaddingStrategy
|
341 |
+
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
|
342 |
+
padding=padding, max_length=max_length, verbose=verbose
|
343 |
+
)
|
344 |
+
|
345 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
346 |
+
if required_input and not isinstance(required_input[0], (list, tuple)):
|
347 |
+
encoded_inputs = self._pad(
|
348 |
+
encoded_inputs,
|
349 |
+
class_type=class_type,
|
350 |
+
max_length=max_length,
|
351 |
+
padding_strategy=padding_strategy,
|
352 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
353 |
+
return_attention_mask=return_attention_mask,
|
354 |
+
)
|
355 |
+
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
|
356 |
+
|
357 |
+
batch_size = len(required_input)
|
358 |
+
assert all(
|
359 |
+
len(v) == batch_size for v in encoded_inputs.values()
|
360 |
+
), "Some items in the output dictionary have a different batch size than others."
|
361 |
+
|
362 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
363 |
+
max_length = max(len(inputs) for inputs in required_input)
|
364 |
+
padding_strategy = PaddingStrategy.MAX_LENGTH
|
365 |
+
|
366 |
+
batch_outputs = {}
|
367 |
+
for i in range(batch_size):
|
368 |
+
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
|
369 |
+
outputs = self._pad(
|
370 |
+
inputs,
|
371 |
+
class_type=class_type,
|
372 |
+
max_length=max_length,
|
373 |
+
padding_strategy=padding_strategy,
|
374 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
375 |
+
return_attention_mask=return_attention_mask,
|
376 |
+
)
|
377 |
+
|
378 |
+
for key, value in outputs.items():
|
379 |
+
if key not in batch_outputs:
|
380 |
+
batch_outputs[key] = []
|
381 |
+
batch_outputs[key].append(value)
|
382 |
+
if class_type == "cell":
|
383 |
+
del batch_outputs["label"]
|
384 |
+
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
385 |
+
|
386 |
+
def _pad(
|
387 |
+
self,
|
388 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
389 |
+
class_type, # options: "gene" or "cell"
|
390 |
+
max_length: Optional[int] = None,
|
391 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
392 |
+
pad_to_multiple_of: Optional[int] = None,
|
393 |
+
return_attention_mask: Optional[bool] = True,
|
394 |
+
) -> dict:
|
395 |
+
"""
|
396 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
397 |
+
|
398 |
+
Args:
|
399 |
+
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
400 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
401 |
+
Will truncate by taking into account the special tokens.
|
402 |
+
padding_strategy: PaddingStrategy to use for padding.
|
403 |
+
|
404 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
405 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
406 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
407 |
+
The tokenizer padding sides are defined in self.padding_side:
|
408 |
+
|
409 |
+
- 'left': pads on the left of the sequences
|
410 |
+
- 'right': pads on the right of the sequences
|
411 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
412 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
413 |
+
>= 7.5 (Volta).
|
414 |
+
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
415 |
+
"""
|
416 |
+
# Load from model defaults
|
417 |
+
if return_attention_mask is None:
|
418 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
419 |
+
|
420 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
421 |
+
|
422 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
423 |
+
max_length = len(required_input)
|
424 |
+
|
425 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
426 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
427 |
+
|
428 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
429 |
+
|
430 |
+
if needs_to_be_padded:
|
431 |
+
difference = max_length - len(required_input)
|
432 |
+
if self.padding_side == "right":
|
433 |
+
if return_attention_mask:
|
434 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
435 |
+
if "token_type_ids" in encoded_inputs:
|
436 |
+
encoded_inputs["token_type_ids"] = (
|
437 |
+
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
438 |
+
)
|
439 |
+
if "special_tokens_mask" in encoded_inputs:
|
440 |
+
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
441 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
442 |
+
if class_type == "gene":
|
443 |
+
encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
|
444 |
+
elif self.padding_side == "left":
|
445 |
+
if return_attention_mask:
|
446 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
447 |
+
if "token_type_ids" in encoded_inputs:
|
448 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
449 |
+
"token_type_ids"
|
450 |
+
]
|
451 |
+
if "special_tokens_mask" in encoded_inputs:
|
452 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
453 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
454 |
+
if class_type == "gene":
|
455 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
|
456 |
+
else:
|
457 |
+
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
458 |
+
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
459 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
460 |
+
|
461 |
+
return encoded_inputs
|
462 |
+
|
463 |
+
def get_special_tokens_mask(
|
464 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
465 |
+
) -> List[int]:
|
466 |
+
"""
|
467 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
468 |
+
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
469 |
+
Args:
|
470 |
+
token_ids_0 (:obj:`List[int]`):
|
471 |
+
List of ids of the first sequence.
|
472 |
+
token_ids_1 (:obj:`List[int]`, `optional`):
|
473 |
+
List of ids of the second sequence.
|
474 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
475 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
476 |
+
Returns:
|
477 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
478 |
+
"""
|
479 |
+
assert already_has_special_tokens and token_ids_1 is None, (
|
480 |
+
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
|
481 |
+
"Please use a slow (full python) tokenizer to activate this argument."
|
482 |
+
"Or set `return_special_tokens_mask=True` when calling the encoding method "
|
483 |
+
"to get the special tokens mask in any tokenizer. "
|
484 |
+
)
|
485 |
+
|
486 |
+
all_special_ids = self.all_special_ids # cache the property
|
487 |
+
|
488 |
+
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
|
489 |
+
|
490 |
+
return special_tokens_mask
|
491 |
+
|
492 |
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
493 |
+
"""
|
494 |
+
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
495 |
+
vocabulary.
|
496 |
+
Args:
|
497 |
+
tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
|
498 |
+
Returns:
|
499 |
+
:obj:`int` or :obj:`List[int]`: The token id or list of token ids.
|
500 |
+
"""
|
501 |
+
if tokens is None:
|
502 |
+
return None
|
503 |
+
|
504 |
+
if isinstance(tokens, str):
|
505 |
+
return self._convert_token_to_id_with_added_voc(tokens)
|
506 |
+
|
507 |
+
ids = []
|
508 |
+
for token in tokens:
|
509 |
+
ids.append(self._convert_token_to_id_with_added_voc(token))
|
510 |
+
return ids
|
511 |
+
|
512 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
513 |
+
if token is None:
|
514 |
+
return None
|
515 |
+
|
516 |
+
return token_dictionary.get(token)
|
517 |
+
|
518 |
+
def __len__(self):
|
519 |
+
return len(token_dictionary)
|
520 |
+
|
521 |
+
|
522 |
+
# collator functions
|
523 |
+
|
524 |
+
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
525 |
+
"""
|
526 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
527 |
+
Args:
|
528 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
529 |
+
The tokenizer used for encoding the data.
|
530 |
+
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
531 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
532 |
+
among:
|
533 |
+
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
534 |
+
sequence if provided).
|
535 |
+
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
536 |
+
maximum acceptable input length for the model if that argument is not provided.
|
537 |
+
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
538 |
+
different lengths).
|
539 |
+
max_length (:obj:`int`, `optional`):
|
540 |
+
Maximum length of the returned list and optionally padding length (see above).
|
541 |
+
pad_to_multiple_of (:obj:`int`, `optional`):
|
542 |
+
If set will pad the sequence to a multiple of the provided value.
|
543 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
544 |
+
7.5 (Volta).
|
545 |
+
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
|
546 |
+
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
547 |
+
"""
|
548 |
+
|
549 |
+
tokenizer = PrecollatorForGeneAndCellClassification()
|
550 |
+
class_type = "gene"
|
551 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
552 |
+
max_length: Optional[int] = None
|
553 |
+
pad_to_multiple_of: Optional[int] = None
|
554 |
+
label_pad_token_id: int = -100
|
555 |
+
|
556 |
+
def __init__(self, *args, **kwargs) -> None:
|
557 |
+
super().__init__(
|
558 |
+
tokenizer=self.tokenizer,
|
559 |
+
padding=self.padding,
|
560 |
+
max_length=self.max_length,
|
561 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
562 |
+
label_pad_token_id=self.label_pad_token_id,
|
563 |
+
*args, **kwargs)
|
564 |
+
|
565 |
+
def _prepare_batch(self, features):
|
566 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
567 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
568 |
+
batch = self.tokenizer.pad(
|
569 |
+
features,
|
570 |
+
class_type=self.class_type,
|
571 |
+
padding=self.padding,
|
572 |
+
max_length=self.max_length,
|
573 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
574 |
+
return_tensors="pt",
|
575 |
+
)
|
576 |
+
return batch
|
577 |
+
|
578 |
+
def __call__(self, features):
|
579 |
+
batch = self._prepare_batch(features)
|
580 |
+
|
581 |
+
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
582 |
+
return batch
|
583 |
+
|
584 |
+
|
585 |
+
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
586 |
+
|
587 |
+
class_type = "cell"
|
588 |
+
|
589 |
+
def _prepare_batch(self, features):
|
590 |
+
|
591 |
+
batch = super()._prepare_batch(features)
|
592 |
+
|
593 |
+
# Special handling for labels.
|
594 |
+
# Ensure that tensor is created with the correct type
|
595 |
+
# (it should be automatically the case, but let's make sure of it.)
|
596 |
+
first = features[0]
|
597 |
+
if "label" in first and first["label"] is not None:
|
598 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
599 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
600 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
601 |
+
|
602 |
+
return batch
|
geneformer/emb_extractor.py
ADDED
@@ -0,0 +1,493 @@
1 |
+
"""
|
2 |
+
Geneformer embedding extractor.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
from geneformer import EmbExtractor
|
6 |
+
embex = EmbExtractor(model_type="CellClassifier",
|
7 |
+
num_classes=3,
|
8 |
+
emb_mode="cell",
|
9 |
+
cell_emb_style="mean_pool",
|
10 |
+
filter_data={"cell_type":["cardiomyocyte"]},
|
11 |
+
max_ncells=1000,
|
12 |
+
max_ncells_to_plot=1000,
|
13 |
+
emb_layer=-1,
|
14 |
+
emb_label=["disease","cell_type"],
|
15 |
+
labels_to_plot=["disease","cell_type"],
|
16 |
+
forward_batch_size=100,
|
17 |
+
nproc=16,
|
18 |
+
summary_stat=None)
|
19 |
+
embs = embex.extract_embs("path/to/model",
|
20 |
+
"path/to/input_data",
|
21 |
+
"path/to/output_directory",
|
22 |
+
"output_prefix")
|
23 |
+
embex.plot_embs(embs=embs,
|
24 |
+
plot_style="heatmap",
|
25 |
+
output_directory="path/to/output_directory",
|
26 |
+
output_prefix="output_prefix")
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
# imports
|
31 |
+
import logging
|
32 |
+
import anndata
|
33 |
+
import matplotlib.pyplot as plt
|
34 |
+
import numpy as np
|
35 |
+
import pandas as pd
|
36 |
+
import pickle
|
37 |
+
from tdigest import TDigest
|
38 |
+
import scanpy as sc
|
39 |
+
import seaborn as sns
|
40 |
+
import torch
|
41 |
+
from collections import Counter
|
42 |
+
from pathlib import Path
|
43 |
+
from tqdm.notebook import trange
|
44 |
+
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
45 |
+
|
46 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
47 |
+
|
48 |
+
from .in_silico_perturber import downsample_and_sort, \
|
49 |
+
gen_attention_mask, \
|
50 |
+
get_model_input_size, \
|
51 |
+
load_and_filter, \
|
52 |
+
load_model, \
|
53 |
+
mean_nonpadding_embs, \
|
54 |
+
pad_tensor_list, \
|
55 |
+
quant_layers
|
56 |
+
|
57 |
+
logger = logging.getLogger(__name__)
|
58 |
+
|
59 |
+
# extract embeddings
|
60 |
+
def get_embs(model,
|
61 |
+
filtered_input_data,
|
62 |
+
emb_mode,
|
63 |
+
layer_to_quant,
|
64 |
+
pad_token_id,
|
65 |
+
forward_batch_size,
|
66 |
+
summary_stat):
|
67 |
+
|
68 |
+
model_input_size = get_model_input_size(model)
|
69 |
+
total_batch_length = len(filtered_input_data)
|
70 |
+
|
71 |
+
if summary_stat is None:
|
72 |
+
embs_list = []
|
73 |
+
elif summary_stat is not None:
|
74 |
+
# test embedding extraction for example cell and extract # emb dims
|
75 |
+
example = filtered_input_data.select([i for i in range(1)])
|
76 |
+
example.set_format(type="torch")
|
77 |
+
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
78 |
+
# initiate tdigests for # of emb dims
|
79 |
+
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
80 |
+
|
81 |
+
for i in trange(0, total_batch_length, forward_batch_size):
|
82 |
+
max_range = min(i+forward_batch_size, total_batch_length)
|
83 |
+
|
84 |
+
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
85 |
+
max_len = max(minibatch["length"])
|
86 |
+
original_lens = torch.tensor(minibatch["length"]).to("cuda")
|
87 |
+
minibatch.set_format(type="torch")
|
88 |
+
|
89 |
+
input_data_minibatch = minibatch["input_ids"]
|
90 |
+
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
91 |
+
max_len,
|
92 |
+
pad_token_id,
|
93 |
+
model_input_size)
|
94 |
+
|
95 |
+
with torch.no_grad():
|
96 |
+
outputs = model(
|
97 |
+
input_ids = input_data_minibatch.to("cuda"),
|
98 |
+
attention_mask = gen_attention_mask(minibatch)
|
99 |
+
)
|
100 |
+
|
101 |
+
embs_i = outputs.hidden_states[layer_to_quant]
|
102 |
+
|
103 |
+
if emb_mode == "cell":
|
104 |
+
mean_embs = mean_nonpadding_embs(embs_i, original_lens)
|
105 |
+
if summary_stat is None:
|
106 |
+
embs_list += [mean_embs]
|
107 |
+
elif summary_stat is not None:
|
108 |
+
# update tdigests with current batch for each emb dim
|
109 |
+
# note: tdigest batch update known to be slow so updating serially
|
110 |
+
[embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
|
111 |
+
|
112 |
+
del outputs
|
113 |
+
del minibatch
|
114 |
+
del input_data_minibatch
|
115 |
+
del embs_i
|
116 |
+
del mean_embs
|
117 |
+
torch.cuda.empty_cache()
|
118 |
+
|
119 |
+
if summary_stat is None:
|
120 |
+
embs_stack = torch.cat(embs_list)
|
121 |
+
# calculate summary stat embs from approximated tdigests
|
122 |
+
elif summary_stat is not None:
|
123 |
+
if summary_stat == "mean":
|
124 |
+
summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
|
125 |
+
elif summary_stat == "median":
|
126 |
+
summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
127 |
+
embs_stack = torch.tensor(summary_emb_list)
|
128 |
+
|
129 |
+
return embs_stack
|
130 |
+
|
131 |
+
def test_emb(model, example, layer_to_quant):
|
132 |
+
with torch.no_grad():
|
133 |
+
outputs = model(
|
134 |
+
input_ids = example.to("cuda")
|
135 |
+
)
|
136 |
+
|
137 |
+
embs_test = outputs.hidden_states[layer_to_quant]
|
138 |
+
return embs_test.size()[2]
|
139 |
+
|
140 |
+
def label_embs(embs, downsampled_data, emb_labels):
|
141 |
+
embs_df = pd.DataFrame(embs.cpu())
|
142 |
+
if emb_labels is not None:
|
143 |
+
for label in emb_labels:
|
144 |
+
emb_label = downsampled_data[label]
|
145 |
+
embs_df[label] = emb_label
|
146 |
+
return embs_df
|
147 |
+
|
148 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
149 |
+
only_embs_df = embs_df.iloc[:,:emb_dims]
|
150 |
+
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
151 |
+
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
|
152 |
+
vars_dict = {"embs": only_embs_df.columns}
|
153 |
+
obs_dict = {"cell_id": list(only_embs_df.index),
|
154 |
+
f"{label}": list(embs_df[label])}
|
155 |
+
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
156 |
+
sc.tl.pca(adata, svd_solver='arpack')
|
157 |
+
sc.pp.neighbors(adata)
|
158 |
+
sc.tl.umap(adata)
|
159 |
+
sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
|
160 |
+
sns.set_style("white")
|
161 |
+
default_kwargs_dict = {"palette":"Set2", "size":200}
|
162 |
+
if kwargs_dict is not None:
|
163 |
+
default_kwargs_dict.update(kwargs_dict)
|
164 |
+
|
165 |
+
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
166 |
+
|
167 |
+
def gen_heatmap_class_colors(labels, df):
|
168 |
+
pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
|
169 |
+
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
170 |
+
colors = pd.Series(labels, index=df.index).map(lut)
|
171 |
+
return colors
|
172 |
+
|
173 |
+
def gen_heatmap_class_dict(classes, label_colors_series):
|
174 |
+
class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
|
175 |
+
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
176 |
+
return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"]))
|
177 |
+
|
178 |
+
def make_colorbar(embs_df, label):
|
179 |
+
|
180 |
+
labels = list(embs_df[label])
|
181 |
+
|
182 |
+
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
183 |
+
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
184 |
+
|
185 |
+
for i,row in label_colors.iterrows():
|
186 |
+
colors=row[0]
|
187 |
+
if len(colors)!=3 or any(np.isnan(colors)):
|
188 |
+
print(i,colors)
|
189 |
+
|
190 |
+
label_colors.isna().sum()
|
191 |
+
|
192 |
+
# create dictionary for colors and classes
|
193 |
+
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
194 |
+
return label_colors, label_color_dict
|
195 |
+
|
196 |
+
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
197 |
+
sns.set_style("white")
|
198 |
+
sns.set(font_scale=2)
|
199 |
+
plt.figure(figsize=(15, 15), dpi=150)
|
200 |
+
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
201 |
+
|
202 |
+
default_kwargs_dict = {"row_cluster": True,
|
203 |
+
"col_cluster": True,
|
204 |
+
"row_colors": label_colors,
|
205 |
+
"standard_scale": 1,
|
206 |
+
"linewidths": 0,
|
207 |
+
"xticklabels": False,
|
208 |
+
"yticklabels": False,
|
209 |
+
"figsize": (15,15),
|
210 |
+
"center": 0,
|
211 |
+
"cmap": "magma"}
|
212 |
+
|
213 |
+
if kwargs_dict is not None:
|
214 |
+
default_kwargs_dict.update(kwargs_dict)
|
215 |
+
g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)
|
216 |
+
|
217 |
+
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
218 |
+
|
219 |
+
for label_color in list(label_color_dict.keys()):
|
220 |
+
g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0)
|
221 |
+
|
222 |
+
l1 = g.ax_col_dendrogram.legend(title=f"{label}",
|
223 |
+
loc="lower center",
|
224 |
+
ncol=4,
|
225 |
+
bbox_to_anchor=(0.5, 1),
|
226 |
+
facecolor="white")
|
227 |
+
|
228 |
+
plt.savefig(output_file, bbox_inches='tight')
|
229 |
+
|
230 |
+
class EmbExtractor:
|
231 |
+
valid_option_dict = {
|
232 |
+
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
233 |
+
"num_classes": {int},
|
234 |
+
"emb_mode": {"cell","gene"},
|
235 |
+
"cell_emb_style": {"mean_pool"},
|
236 |
+
"filter_data": {None, dict},
|
237 |
+
"max_ncells": {None, int},
|
238 |
+
"emb_layer": {-1, 0},
|
239 |
+
"emb_label": {None, list},
|
240 |
+
"labels_to_plot": {None, list},
|
241 |
+
"forward_batch_size": {int},
|
242 |
+
"nproc": {int},
|
243 |
+
"summary_stat": {None, "mean", "median"},
|
244 |
+
}
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
model_type="Pretrained",
|
248 |
+
num_classes=0,
|
249 |
+
emb_mode="cell",
|
250 |
+
cell_emb_style="mean_pool",
|
251 |
+
filter_data=None,
|
252 |
+
max_ncells=1000,
|
253 |
+
emb_layer=-1,
|
254 |
+
emb_label=None,
|
255 |
+
labels_to_plot=None,
|
256 |
+
forward_batch_size=100,
|
257 |
+
nproc=4,
|
258 |
+
summary_stat=None,
|
259 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
260 |
+
):
|
261 |
+
"""
|
262 |
+
Initialize embedding extractor.
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
model_type : {"Pretrained","GeneClassifier","CellClassifier"}
|
267 |
+
Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
268 |
+
num_classes : int
|
269 |
+
If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
270 |
+
For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
271 |
+
emb_mode : {"cell","gene"}
|
272 |
+
Whether to output cell or gene embeddings.
|
273 |
+
cell_emb_style : "mean_pool"
|
274 |
+
Method for summarizing cell embeddings.
|
275 |
+
Currently only option is mean pooling of gene embeddings for given cell.
|
276 |
+
filter_data : None, dict
|
277 |
+
Default is to extract embeddings from all input data.
|
278 |
+
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
279 |
+
max_ncells : None, int
|
280 |
+
Maximum number of cells to extract embeddings from.
|
281 |
+
Default is 1000 cells randomly sampled from input data.
|
282 |
+
If None, will extract embeddings from all cells.
|
283 |
+
emb_layer : {-1, 0}
|
284 |
+
Embedding layer to extract.
|
285 |
+
The last layer is most specifically weighted to optimize the given learning objective.
|
286 |
+
Generally, it is best to extract the 2nd to last layer to get a more general representation.
|
287 |
+
-1: 2nd to last layer
|
288 |
+
0: last layer
|
289 |
+
emb_label : None, list
|
290 |
+
List of column name(s) in .dataset to add as labels to embedding output.
|
291 |
+
labels_to_plot : None, list
|
292 |
+
Cell labels to plot.
|
293 |
+
Shown as color bar in heatmap.
|
294 |
+
Shown as cell color in umap.
|
295 |
+
Plotting umap requires labels to plot.
|
296 |
+
forward_batch_size : int
|
297 |
+
Batch size for forward pass.
|
298 |
+
nproc : int
|
299 |
+
Number of CPU processes to use.
|
300 |
+
summary_stat : {None, "mean", "median"}
|
301 |
+
If not None, outputs only approximated mean or median embedding of input data.
|
302 |
+
Recommended if encountering memory constraints while generating goal embedding positions.
|
303 |
+
Slower but more memory-efficient.
|
304 |
+
token_dictionary_file : Path
|
305 |
+
Path to pickle file containing token dictionary (Ensembl ID:token).
|
306 |
+
"""
|
307 |
+
|
308 |
+
self.model_type = model_type
|
309 |
+
self.num_classes = num_classes
|
310 |
+
self.emb_mode = emb_mode
|
311 |
+
self.cell_emb_style = cell_emb_style
|
312 |
+
self.filter_data = filter_data
|
313 |
+
self.max_ncells = max_ncells
|
314 |
+
self.emb_layer = emb_layer
|
315 |
+
self.emb_label = emb_label
|
316 |
+
self.labels_to_plot = labels_to_plot
|
317 |
+
self.forward_batch_size = forward_batch_size
|
318 |
+
self.nproc = nproc
|
319 |
+
self.summary_stat = summary_stat
|
320 |
+
|
321 |
+
self.validate_options()
|
322 |
+
|
323 |
+
# load token dictionary (Ensembl IDs:token)
|
324 |
+
with open(token_dictionary_file, "rb") as f:
|
325 |
+
self.gene_token_dict = pickle.load(f)
|
326 |
+
|
327 |
+
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
328 |
+
|
329 |
+
|
330 |
+
def validate_options(self):
|
331 |
+
# first disallow options under development
|
332 |
+
if self.emb_mode == "gene":
|
333 |
+
logger.error(
|
334 |
+
"Extraction and plotting of gene-level embeddings currently under development. " \
|
335 |
+
"Current valid option for 'emb_mode': 'cell'"
|
336 |
+
)
|
337 |
+
raise
|
338 |
+
|
339 |
+
# confirm arguments are within valid options and compatible with each other
|
340 |
+
for attr_name,valid_options in self.valid_option_dict.items():
|
341 |
+
attr_value = self.__dict__[attr_name]
|
342 |
+
if type(attr_value) not in {list, dict}:
|
343 |
+
if attr_value in valid_options:
|
344 |
+
continue
|
345 |
+
valid_type = False
|
346 |
+
for option in valid_options:
|
347 |
+
if (option in [int,list,dict]) and isinstance(attr_value, option):
|
348 |
+
valid_type = True
|
349 |
+
break
|
350 |
+
if valid_type:
|
351 |
+
continue
|
352 |
+
logger.error(
|
353 |
+
f"Invalid option for {attr_name}. " \
|
354 |
+
f"Valid options for {attr_name}: {valid_options}"
|
355 |
+
)
|
356 |
+
raise
|
357 |
+
|
358 |
+
if self.filter_data is not None:
|
359 |
+
for key,value in self.filter_data.items():
|
360 |
+
if type(value) != list:
|
361 |
+
self.filter_data[key] = [value]
|
362 |
+
logger.warning(
|
363 |
+
"Values in filter_data dict must be lists. " \
|
364 |
+
f"Changing {key} value to list ([{value}]).")
|
365 |
+
|
366 |
+
def extract_embs(self,
|
367 |
+
model_directory,
|
368 |
+
input_data_file,
|
369 |
+
output_directory,
|
370 |
+
output_prefix):
|
371 |
+
"""
|
372 |
+
Extract embeddings from input data and save as results in output_directory.
|
373 |
+
|
374 |
+
Parameters
|
375 |
+
----------
|
376 |
+
model_directory : Path
|
377 |
+
Path to directory containing model
|
378 |
+
input_data_file : Path
|
379 |
+
Path to directory containing .dataset inputs
|
380 |
+
output_directory : Path
|
381 |
+
Path to directory where embedding data will be saved as csv
|
382 |
+
output_prefix : str
|
383 |
+
Prefix for output file
|
384 |
+
"""
|
385 |
+
|
386 |
+
filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
|
387 |
+
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
388 |
+
model = load_model(self.model_type, self.num_classes, model_directory)
|
389 |
+
layer_to_quant = quant_layers(model)+self.emb_layer
|
390 |
+
embs = get_embs(model,
|
391 |
+
downsampled_data,
|
392 |
+
self.emb_mode,
|
393 |
+
layer_to_quant,
|
394 |
+
self.pad_token_id,
|
395 |
+
self.forward_batch_size,
|
396 |
+
self.summary_stat)
|
397 |
+
|
398 |
+
if self.summary_stat is None:
|
399 |
+
embs_df = label_embs(embs, downsampled_data, self.emb_label)
|
400 |
+
elif self.summary_stat is not None:
|
401 |
+
embs_df = pd.DataFrame(embs.cpu()).T
|
402 |
+
|
403 |
+
# save embeddings to output_path
|
404 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
405 |
+
embs_df.to_csv(output_path)
|
406 |
+
|
407 |
+
return embs_df
|
408 |
+
|
409 |
+
def plot_embs(self,
|
410 |
+
embs,
|
411 |
+
plot_style,
|
412 |
+
output_directory,
|
413 |
+
output_prefix,
|
414 |
+
max_ncells_to_plot=1000,
|
415 |
+
kwargs_dict=None):
|
416 |
+
|
417 |
+
"""
|
418 |
+
Plot embeddings, coloring by provided labels.
|
419 |
+
|
420 |
+
Parameters
|
421 |
+
----------
|
422 |
+
embs : pandas.core.frame.DataFrame
|
423 |
+
Pandas dataframe containing embeddings output from extract_embs
|
424 |
+
plot_style : str
|
425 |
+
Style of plot: "heatmap" or "umap"
|
426 |
+
output_directory : Path
|
427 |
+
Path to directory where plots will be saved as pdf
|
428 |
+
output_prefix : str
|
429 |
+
Prefix for output file
|
430 |
+
max_ncells_to_plot : None, int
|
431 |
+
Maximum number of cells to plot.
|
432 |
+
Default is 1000 cells randomly sampled from embeddings.
|
433 |
+
If None, will plot embeddings from all cells.
|
434 |
+
kwargs_dict : dict
|
435 |
+
Dictionary of kwargs to pass to plotting function.
|
436 |
+
"""
|
437 |
+
|
438 |
+
if plot_style not in ["heatmap","umap"]:
|
439 |
+
logger.error(
|
440 |
+
"Invalid option for 'plot_style'. " \
|
441 |
+
"Valid options: {'heatmap','umap'}"
|
442 |
+
)
|
443 |
+
raise
|
444 |
+
|
445 |
+
if (plot_style == "umap") and (self.labels_to_plot is None):
|
446 |
+
logger.error(
|
447 |
+
"Plotting UMAP requires 'labels_to_plot'. "
|
448 |
+
)
|
449 |
+
raise
|
450 |
+
|
451 |
+
if (max_ncells_to_plot is not None) and (self.max_ncells is not None) and (max_ncells_to_plot > self.max_ncells):
|
452 |
+
max_ncells_to_plot = self.max_ncells
|
453 |
+
logger.warning(
|
454 |
+
"max_ncells_to_plot must be <= max_ncells. " \
|
455 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}.")
|
456 |
+
|
457 |
+
if (max_ncells_to_plot is not None) \
|
458 |
+
and ((self.max_ncells is None) or (max_ncells_to_plot < self.max_ncells)):
|
459 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
460 |
+
|
461 |
+
if self.emb_label is None:
|
462 |
+
label_len = 0
|
463 |
+
else:
|
464 |
+
label_len = len(self.emb_label)
|
465 |
+
|
466 |
+
emb_dims = embs.shape[1] - label_len
|
467 |
+
|
468 |
+
if self.emb_label is None:
|
469 |
+
emb_labels = None
|
470 |
+
else:
|
471 |
+
emb_labels = embs.columns[emb_dims:]
|
472 |
+
|
473 |
+
if plot_style == "umap":
|
474 |
+
for label in self.labels_to_plot:
|
475 |
+
if label not in emb_labels:
|
476 |
+
logger.warning(
|
477 |
+
f"Label {label} from labels_to_plot " \
|
478 |
+
f"not present in provided embeddings dataframe.")
|
479 |
+
continue
|
480 |
+
output_prefix_label = "_" + output_prefix + f"_umap_{label}"
|
481 |
+
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
482 |
+
plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
|
483 |
+
|
484 |
+
if plot_style == "heatmap":
|
485 |
+
for label in self.labels_to_plot:
|
486 |
+
if label not in emb_labels:
|
487 |
+
logger.warning(
|
488 |
+
f"Label {label} from labels_to_plot " \
|
489 |
+
f"not present in provided embeddings dataframe.")
|
490 |
+
continue
|
491 |
+
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
492 |
+
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
493 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
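To show how the pieces of emb_extractor.py above fit together, here is a minimal, hypothetical usage sketch. The constructor keywords mirror the attributes assigned in __init__ above, the paths follow the placeholder style of this repo's docstrings, and the import assumes EmbExtractor is exported from geneformer/__init__.py; none of this is part of the committed file itself.

from geneformer import EmbExtractor

# hypothetical configuration; valid options are documented in the class above
embex = EmbExtractor(model_type="CellClassifier",
                     num_classes=3,
                     emb_mode="cell",
                     max_ncells=1000,
                     emb_label=["cell_type"],
                     labels_to_plot=["cell_type"],
                     forward_batch_size=100,
                     nproc=4)

# extract cell embeddings and save them as csv in the output directory
embs_df = embex.extract_embs("path/to/model",
                             "path/to/input_data",
                             "path/to/output_directory",
                             "output_prefix")

# plot the extracted embeddings, colored by the provided label
embex.plot_embs(embs=embs_df,
                plot_style="umap",
                output_directory="path/to/output_directory",
                output_prefix="output_prefix")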
geneformer/gene_median_dictionary.pkl
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3129017daec18ff275f0900e674957d9b6547af266ef0e2c97b03d20b5d4c225
|
3 |
+
size 1640760
|
geneformer/gene_name_id_dict.pkl
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90f7100adec84828555873be1bae866e83509ce016dedfd9633d12e01dee4ea4
|
3 |
+
size 607393
|
geneformer/in_silico_perturber.py
ADDED
@@ -0,0 +1,1297 @@
1 |
+
"""
|
2 |
+
Geneformer in silico perturber.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
from geneformer import InSilicoPerturber
|
6 |
+
isp = InSilicoPerturber(perturb_type="delete",
|
7 |
+
perturb_rank_shift=None,
|
8 |
+
genes_to_perturb="all",
|
9 |
+
combos=0,
|
10 |
+
anchor_gene=None,
|
11 |
+
model_type="Pretrained",
|
12 |
+
num_classes=0,
|
13 |
+
emb_mode="cell",
|
14 |
+
cell_emb_style="mean_pool",
|
15 |
+
filter_data={"cell_type":["cardiomyocyte"]},
|
16 |
+
cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
|
17 |
+
max_ncells=None,
|
18 |
+
emb_layer=-1,
|
19 |
+
forward_batch_size=100,
|
20 |
+
nproc=4)
|
21 |
+
isp.perturb_data("path/to/model",
|
22 |
+
"path/to/input_data",
|
23 |
+
"path/to/output_directory",
|
24 |
+
"output_prefix")
|
25 |
+
"""
|
26 |
+
|
27 |
+
# imports
|
28 |
+
import itertools as it
|
29 |
+
import logging
|
30 |
+
import numpy as np
|
31 |
+
import pickle
|
32 |
+
import re
|
33 |
+
import seaborn as sns; sns.set()
|
34 |
+
import torch
|
35 |
+
from collections import defaultdict
|
36 |
+
from datasets import Dataset, load_from_disk
|
37 |
+
from tqdm.notebook import trange
|
38 |
+
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
39 |
+
|
40 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
# load data and filter by defined criteria
|
46 |
+
def load_and_filter(filter_data, nproc, input_data_file):
|
47 |
+
data = load_from_disk(input_data_file)
|
48 |
+
if filter_data is not None:
|
49 |
+
for key,value in filter_data.items():
|
50 |
+
def filter_data_by_criteria(example):
|
51 |
+
return example[key] in value
|
52 |
+
data = data.filter(filter_data_by_criteria, num_proc=nproc)
|
53 |
+
if len(data) == 0:
|
54 |
+
logger.error(
|
55 |
+
"No cells remain after filtering. Check filtering criteria.")
|
56 |
+
raise
|
57 |
+
data_shuffled = data.shuffle(seed=42)
|
58 |
+
return data_shuffled
|
59 |
+
|
60 |
+
# load model to GPU
|
61 |
+
def load_model(model_type, num_classes, model_directory):
|
62 |
+
if model_type == "Pretrained":
|
63 |
+
model = BertForMaskedLM.from_pretrained(model_directory,
|
64 |
+
output_hidden_states=True,
|
65 |
+
output_attentions=False)
|
66 |
+
elif model_type == "GeneClassifier":
|
67 |
+
model = BertForTokenClassification.from_pretrained(model_directory,
|
68 |
+
num_labels=num_classes,
|
69 |
+
output_hidden_states=True,
|
70 |
+
output_attentions=False)
|
71 |
+
elif model_type == "CellClassifier":
|
72 |
+
model = BertForSequenceClassification.from_pretrained(model_directory,
|
73 |
+
num_labels=num_classes,
|
74 |
+
output_hidden_states=True,
|
75 |
+
output_attentions=False)
|
76 |
+
# put the model in eval mode for fwd pass
|
77 |
+
model.eval()
|
78 |
+
model = model.to("cuda:0")
|
79 |
+
return model
|
80 |
+
|
81 |
+
def quant_layers(model):
|
82 |
+
layer_nums = []
|
83 |
+
for name, parameter in model.named_parameters():
|
84 |
+
if "layer" in name:
|
85 |
+
layer_nums += [int(name.split("layer.")[1].split(".")[0])]
|
86 |
+
return int(max(layer_nums))+1
|
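As a quick orientation to how this count is used downstream (layer_to_quant = quant_layers(model) + self.emb_layer), a small illustrative calculation, assuming a 6-layer model such as the fine-tuned 6L checkpoint in this commit:

# hidden_states holds (number of layers + 1) tensors: index 0 is the embedding
# output and the last index is the final encoder layer
n_layers = 6                       # what quant_layers(model) would return here
second_to_last = n_layers + (-1)   # emb_layer = -1 -> hidden_states[5]
last = n_layers + 0                # emb_layer =  0 -> hidden_states[6]
print(second_to_last, last)        # 5 6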
87 |
+
|
88 |
+
def get_model_input_size(model):
|
89 |
+
return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
90 |
+
|
91 |
+
def flatten_list(megalist):
|
92 |
+
return [item for sublist in megalist for item in sublist]
|
93 |
+
|
94 |
+
def measure_length(example):
|
95 |
+
example["length"] = len(example["input_ids"])
|
96 |
+
return example
|
97 |
+
|
98 |
+
def downsample_and_sort(data_shuffled, max_ncells):
|
99 |
+
num_cells = len(data_shuffled)
|
100 |
+
# if max number of cells is defined, then subsample to this max number
|
101 |
+
if max_ncells != None:
|
102 |
+
num_cells = min(max_ncells,num_cells)
|
103 |
+
data_subset = data_shuffled.select([i for i in range(num_cells)])
|
104 |
+
# sort dataset with largest cell first to encounter any memory errors earlier
|
105 |
+
data_sorted = data_subset.sort("length",reverse=True)
|
106 |
+
return data_sorted
|
107 |
+
|
108 |
+
def get_possible_states(cell_states_to_model):
|
109 |
+
possible_states = []
|
110 |
+
for key in ["start_state","goal_state"]:
|
111 |
+
possible_states += [cell_states_to_model[key]]
|
112 |
+
possible_states += cell_states_to_model.get("alt_states",[])
|
113 |
+
return possible_states
|
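For instance, with the four-key dictionary format shown in the usage docstring at the top of this file, the helper simply concatenates the start, goal, and alternate states:

cell_states_to_model = {"state_key": "disease",
                        "start_state": "dcm",
                        "goal_state": "nf",
                        "alt_states": ["hcm", "other1", "other2"]}
# get_possible_states(cell_states_to_model) -> ["dcm", "nf", "hcm", "other1", "other2"]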
114 |
+
|
115 |
+
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
116 |
+
example_cell.set_format(type="torch")
|
117 |
+
input_data = example_cell["input_ids"]
|
118 |
+
with torch.no_grad():
|
119 |
+
outputs = model(
|
120 |
+
input_ids = input_data.to("cuda")
|
121 |
+
)
|
122 |
+
emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
|
123 |
+
del outputs
|
124 |
+
return emb
|
125 |
+
|
126 |
+
def perturb_emb_by_index(emb, indices):
|
127 |
+
mask = torch.ones(emb.numel(), dtype=torch.bool)
|
128 |
+
mask[indices] = False
|
129 |
+
return emb[mask]
|
130 |
+
|
131 |
+
def delete_indices(example):
|
132 |
+
indices = example["perturb_index"]
|
133 |
+
if any(isinstance(el, list) for el in indices):
|
134 |
+
indices = flatten_list(indices)
|
135 |
+
for index in sorted(indices, reverse=True):
|
136 |
+
del example["input_ids"][index]
|
137 |
+
return example
|
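A small standalone sketch of what the deletion does to a rank value encoding; the token ids and index are made up for illustration, and the real function operates on dataset examples:

example = {"input_ids": [10, 20, 30], "perturb_index": [[1]]}
# mirror delete_indices: flatten the nested index list, then delete from the end backwards
indices = [i for sublist in example["perturb_index"] for i in sublist]
for index in sorted(indices, reverse=True):
    del example["input_ids"][index]
# example["input_ids"] -> [10, 30]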
138 |
+
|
139 |
+
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
140 |
+
def overexpress_indices(example):
|
141 |
+
indices = example["perturb_index"]
|
142 |
+
if any(isinstance(el, list) for el in indices):
|
143 |
+
indices = flatten_list(indices)
|
144 |
+
for index in sorted(indices, reverse=True):
|
145 |
+
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
146 |
+
return example
|
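Correspondingly, overexpression moves the perturbed gene to the front of the rank value encoding, which in this representation models maximal expression (hypothetical token ids again):

example = {"input_ids": [10, 20, 30, 40], "perturb_index": [[2]]}
indices = [i for sublist in example["perturb_index"] for i in sublist]
for index in sorted(indices, reverse=True):
    # pop the gene at its current rank and reinsert it at rank 0
    example["input_ids"].insert(0, example["input_ids"].pop(index))
# example["input_ids"] -> [30, 10, 20, 40]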
147 |
+
|
148 |
+
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
149 |
+
def overexpress_tokens(example):
|
150 |
+
# -100 indicates tokens to overexpress are not present in rank value encoding
|
151 |
+
if example["perturb_index"] != [-100]:
|
152 |
+
example = delete_indices(example)
|
153 |
+
[example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
|
154 |
+
return example
|
155 |
+
|
156 |
+
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
157 |
+
# indices_to_remove is list of indices to remove
|
158 |
+
indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
|
159 |
+
num_dims = emb.dim()
|
160 |
+
emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
|
161 |
+
sliced_emb = emb[emb_slice]
|
162 |
+
return sliced_emb
|
163 |
+
|
164 |
+
def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
|
165 |
+
output_batch = torch.stack([
|
166 |
+
remove_indices_from_emb(emb_batch[i, :, :], idx, gene_dim-1) for
|
167 |
+
i, idx in enumerate(list_of_indices_to_remove)
|
168 |
+
])
|
169 |
+
return output_batch
|
170 |
+
|
171 |
+
def make_perturbation_batch(example_cell,
|
172 |
+
perturb_type,
|
173 |
+
tokens_to_perturb,
|
174 |
+
anchor_token,
|
175 |
+
combo_lvl,
|
176 |
+
num_proc):
|
177 |
+
if tokens_to_perturb == "all":
|
178 |
+
if perturb_type in ["overexpress","activate"]:
|
179 |
+
range_start = 1
|
180 |
+
elif perturb_type in ["delete","inhibit"]:
|
181 |
+
range_start = 0
|
182 |
+
indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
|
183 |
+
elif combo_lvl>0 and (anchor_token is not None):
|
184 |
+
example_input_ids = example_cell["input_ids "][0]
|
185 |
+
anchor_index = example_input_ids.index(anchor_token[0])
|
186 |
+
indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
|
187 |
+
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
188 |
+
else:
|
189 |
+
example_input_ids = example_cell["input_ids"][0]
|
190 |
+
indices_to_perturb = [[example_input_ids.index(token)] if token in example_input_ids else None for token in tokens_to_perturb]
|
191 |
+
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
192 |
+
|
193 |
+
# create all permutations of combo_lvl of modifiers from tokens_to_perturb
|
194 |
+
if combo_lvl>0 and (anchor_token is None):
|
195 |
+
if tokens_to_perturb != "all":
|
196 |
+
if len(tokens_to_perturb) == combo_lvl+1:
|
197 |
+
indices_to_perturb = [list(x) for x in it.combinations(indices_to_perturb, combo_lvl+1)]
|
198 |
+
else:
|
199 |
+
all_indices = [[i] for i in range(example_cell["length"][0])]
|
200 |
+
all_indices = [index for index in all_indices if index not in indices_to_perturb]
|
201 |
+
indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
|
202 |
+
length = len(indices_to_perturb)
|
203 |
+
perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length,
|
204 |
+
"perturb_index": indices_to_perturb})
|
205 |
+
if length<400:
|
206 |
+
num_proc_i = 1
|
207 |
+
else:
|
208 |
+
num_proc_i = num_proc
|
209 |
+
if perturb_type == "delete":
|
210 |
+
perturbation_dataset = perturbation_dataset.map(delete_indices, num_proc=num_proc_i)
|
211 |
+
elif perturb_type == "overexpress":
|
212 |
+
perturbation_dataset = perturbation_dataset.map(overexpress_indices, num_proc=num_proc_i)
|
213 |
+
return perturbation_dataset, indices_to_perturb
|
214 |
+
|
215 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
|
216 |
+
# so that only non-perturbed gene embeddings are compared to each other
|
217 |
+
# in original or perturbed context
|
218 |
+
def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
|
219 |
+
all_embs_list = []
|
220 |
+
|
221 |
+
# if making comparison batch for multiple perturbations in single cell
|
222 |
+
if perturb_group == False:
|
223 |
+
original_emb_list = [original_emb_batch]*len(indices_to_perturb)
|
224 |
+
# if making comparison batch for single perturbation in multiple cells
|
225 |
+
elif perturb_group == True:
|
226 |
+
original_emb_list = original_emb_batch
|
227 |
+
|
228 |
+
|
229 |
+
for i in range(len(original_emb_list)):
|
230 |
+
original_emb = original_emb_list[i]
|
231 |
+
indices = indices_to_perturb[i]
|
232 |
+
if indices == [-100]:
|
233 |
+
all_embs_list += [original_emb[:]]
|
234 |
+
continue
|
235 |
+
emb_list = []
|
236 |
+
start = 0
|
237 |
+
if any(isinstance(el, list) for el in indices):
|
238 |
+
indices = flatten_list(indices)
|
239 |
+
for i in sorted(indices):
|
240 |
+
emb_list += [original_emb[start:i]]
|
241 |
+
start = i+1
|
242 |
+
emb_list += [original_emb[start:]]
|
243 |
+
all_embs_list += [torch.cat(emb_list)]
|
244 |
+
len_set = set([emb.size()[0] for emb in all_embs_list])
|
245 |
+
if len(len_set) > 1:
|
246 |
+
max_len = max(len_set)
|
247 |
+
all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
|
248 |
+
return torch.stack(all_embs_list)
|
249 |
+
|
250 |
+
# average embedding position of goal cell states
|
251 |
+
def get_cell_state_avg_embs(model,
|
252 |
+
filtered_input_data,
|
253 |
+
cell_states_to_model,
|
254 |
+
layer_to_quant,
|
255 |
+
pad_token_id,
|
256 |
+
forward_batch_size,
|
257 |
+
num_proc):
|
258 |
+
|
259 |
+
model_input_size = get_model_input_size(model)
|
260 |
+
possible_states = get_possible_states(cell_states_to_model)
|
261 |
+
state_embs_dict = dict()
|
262 |
+
for possible_state in possible_states:
|
263 |
+
state_embs_list = []
|
264 |
+
original_lens = []
|
265 |
+
|
266 |
+
def filter_states(example):
|
267 |
+
state_key = cell_states_to_model["state_key"]
|
268 |
+
return example[state_key] in [possible_state]
|
269 |
+
filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
|
270 |
+
total_batch_length = len(filtered_input_data_state)
|
271 |
+
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
272 |
+
forward_batch_size = forward_batch_size-1
|
273 |
+
max_len = max(filtered_input_data_state["length"])
|
274 |
+
for i in range(0, total_batch_length, forward_batch_size):
|
275 |
+
max_range = min(i+forward_batch_size, total_batch_length)
|
276 |
+
|
277 |
+
state_minibatch = filtered_input_data_state.select([i for i in range(i, max_range)])
|
278 |
+
state_minibatch.set_format(type="torch")
|
279 |
+
|
280 |
+
input_data_minibatch = state_minibatch["input_ids"]
|
281 |
+
original_lens += state_minibatch["length"]
|
282 |
+
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
283 |
+
max_len,
|
284 |
+
pad_token_id,
|
285 |
+
model_input_size)
|
286 |
+
attention_mask = gen_attention_mask(state_minibatch, max_len)
|
287 |
+
|
288 |
+
with torch.no_grad():
|
289 |
+
outputs = model(
|
290 |
+
input_ids = input_data_minibatch.to("cuda"),
|
291 |
+
attention_mask = attention_mask
|
292 |
+
)
|
293 |
+
|
294 |
+
state_embs_i = outputs.hidden_states[layer_to_quant]
|
295 |
+
state_embs_list += [state_embs_i]
|
296 |
+
del outputs
|
297 |
+
del state_minibatch
|
298 |
+
del input_data_minibatch
|
299 |
+
del attention_mask
|
300 |
+
del state_embs_i
|
301 |
+
torch.cuda.empty_cache()
|
302 |
+
|
303 |
+
state_embs = torch.cat(state_embs_list)
|
304 |
+
avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
|
305 |
+
avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
|
306 |
+
state_embs_dict[possible_state] = avg_state_emb
|
307 |
+
return state_embs_dict
|
308 |
+
|
309 |
+
# quantify cosine similarity of perturbed vs original or alternate states
|
310 |
+
def quant_cos_sims(model,
|
311 |
+
perturb_type,
|
312 |
+
perturbation_batch,
|
313 |
+
forward_batch_size,
|
314 |
+
layer_to_quant,
|
315 |
+
original_emb,
|
316 |
+
tokens_to_perturb,
|
317 |
+
indices_to_perturb,
|
318 |
+
perturb_group,
|
319 |
+
cell_states_to_model,
|
320 |
+
state_embs_dict,
|
321 |
+
pad_token_id,
|
322 |
+
model_input_size,
|
323 |
+
nproc):
|
324 |
+
cos = torch.nn.CosineSimilarity(dim=2)
|
325 |
+
total_batch_length = len(perturbation_batch)
|
326 |
+
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
327 |
+
forward_batch_size = forward_batch_size-1
|
328 |
+
if cell_states_to_model is None:
|
329 |
+
if perturb_group == False: # (if perturb_group is True, original_emb is filtered_input_data)
|
330 |
+
comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
|
331 |
+
cos_sims = []
|
332 |
+
else:
|
333 |
+
possible_states = get_possible_states(cell_states_to_model)
|
334 |
+
cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
|
335 |
+
|
336 |
+
# measure length of each element in perturbation_batch
|
337 |
+
perturbation_batch = perturbation_batch.map(
|
338 |
+
measure_length, num_proc=nproc
|
339 |
+
)
|
340 |
+
|
341 |
+
for i in range(0, total_batch_length, forward_batch_size):
|
342 |
+
max_range = min(i+forward_batch_size, total_batch_length)
|
343 |
+
|
344 |
+
perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
|
345 |
+
# determine if need to pad or truncate batch
|
346 |
+
minibatch_length_set = set(perturbation_minibatch["length"])
|
347 |
+
minibatch_lengths = perturbation_minibatch["length"]
|
348 |
+
if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
|
349 |
+
needs_pad_or_trunc = True
|
350 |
+
else:
|
351 |
+
needs_pad_or_trunc = False
|
352 |
+
max_len = max(minibatch_length_set)
|
353 |
+
|
354 |
+
if needs_pad_or_trunc == True:
|
355 |
+
max_len = min(max(minibatch_length_set),model_input_size)
|
356 |
+
def pad_or_trunc_example(example):
|
357 |
+
example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
|
358 |
+
pad_token_id,
|
359 |
+
max_len)
|
360 |
+
return example
|
361 |
+
perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
|
362 |
+
|
363 |
+
perturbation_minibatch.set_format(type="torch")
|
364 |
+
|
365 |
+
input_data_minibatch = perturbation_minibatch["input_ids"]
|
366 |
+
attention_mask = gen_attention_mask(perturbation_minibatch, max_len)
|
367 |
+
|
368 |
+
# extract embeddings for perturbation minibatch
|
369 |
+
with torch.no_grad():
|
370 |
+
outputs = model(
|
371 |
+
input_ids = input_data_minibatch.to("cuda"),
|
372 |
+
attention_mask = attention_mask
|
373 |
+
)
|
374 |
+
del input_data_minibatch
|
375 |
+
del perturbation_minibatch
|
376 |
+
del attention_mask
|
377 |
+
|
378 |
+
if len(indices_to_perturb)>1:
|
379 |
+
minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
|
380 |
+
else:
|
381 |
+
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
382 |
+
|
383 |
+
if perturb_type == "overexpress":
|
384 |
+
# remove overexpressed genes to quantify effect on remaining genes
|
385 |
+
if perturb_group == False:
|
386 |
+
overexpressed_to_remove = 1
|
387 |
+
if perturb_group == True:
|
388 |
+
overexpressed_to_remove = len(tokens_to_perturb)
|
389 |
+
minibatch_emb = minibatch_emb[:,overexpressed_to_remove:,:]
|
390 |
+
|
391 |
+
# if quantifying single perturbation in multiple different cells, pad original batch and extract embs
|
392 |
+
if perturb_group == True:
|
393 |
+
# pad minibatch of original batch to extract embeddings
|
394 |
+
# truncate to the (model input size - # tokens to overexpress) to ensure comparability
|
395 |
+
# since max input size of perturb batch will be reduced by # tokens to overexpress
|
396 |
+
original_minibatch = original_emb.select([i for i in range(i, max_range)])
|
397 |
+
original_minibatch_lengths = original_minibatch["length"]
|
398 |
+
original_minibatch_length_set = set(original_minibatch["length"])
|
399 |
+
|
400 |
+
indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
|
401 |
+
|
402 |
+
if perturb_type == "overexpress":
|
403 |
+
new_max_len = model_input_size - len(tokens_to_perturb)
|
404 |
+
else:
|
405 |
+
new_max_len = model_input_size
|
406 |
+
if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
|
407 |
+
new_max_len = min(max(original_minibatch_length_set),new_max_len)
|
408 |
+
def pad_or_trunc_example(example):
|
409 |
+
example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
|
410 |
+
return example
|
411 |
+
original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
|
412 |
+
original_minibatch.set_format(type="torch")
|
413 |
+
original_input_data_minibatch = original_minibatch["input_ids"]
|
414 |
+
attention_mask = gen_attention_mask(original_minibatch, new_max_len)
|
415 |
+
# extract embeddings for original minibatch
|
416 |
+
with torch.no_grad():
|
417 |
+
original_outputs = model(
|
418 |
+
input_ids = original_input_data_minibatch.to("cuda"),
|
419 |
+
attention_mask = attention_mask
|
420 |
+
)
|
421 |
+
del original_input_data_minibatch
|
422 |
+
del original_minibatch
|
423 |
+
del attention_mask
|
424 |
+
|
425 |
+
if len(indices_to_perturb)>1:
|
426 |
+
original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
|
427 |
+
else:
|
428 |
+
original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
|
429 |
+
|
430 |
+
# embedding dimension of the genes
|
431 |
+
gene_dim = 1
|
432 |
+
# exclude overexpression due to case when genes are not expressed but being overexpressed
|
433 |
+
if perturb_type != "overexpress":
|
434 |
+
original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
|
435 |
+
indices_to_perturb_minibatch,
|
436 |
+
gene_dim)
|
437 |
+
|
438 |
+
# cosine similarity between original emb and batch items
|
439 |
+
if cell_states_to_model is None:
|
440 |
+
if perturb_group == False:
|
441 |
+
minibatch_comparison = comparison_batch[i:max_range]
|
442 |
+
elif perturb_group == True:
|
443 |
+
minibatch_comparison = original_minibatch_emb
|
444 |
+
|
445 |
+
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
446 |
+
elif cell_states_to_model is not None:
|
447 |
+
for state in possible_states:
|
448 |
+
if perturb_group == False:
|
449 |
+
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
|
450 |
+
minibatch_emb,
|
451 |
+
state_embs_dict[state],
|
452 |
+
perturb_group)
|
453 |
+
elif perturb_group == True:
|
454 |
+
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
|
455 |
+
minibatch_emb,
|
456 |
+
state_embs_dict[state],
|
457 |
+
perturb_group,
|
458 |
+
torch.tensor(original_minibatch_lengths, device="cuda"),
|
459 |
+
torch.tensor(minibatch_lengths, device="cuda"))
|
460 |
+
del outputs
|
461 |
+
del minibatch_emb
|
462 |
+
if cell_states_to_model is None:
|
463 |
+
del minibatch_comparison
|
464 |
+
torch.cuda.empty_cache()
|
465 |
+
if cell_states_to_model is None:
|
466 |
+
cos_sims_stack = torch.cat(cos_sims)
|
467 |
+
return cos_sims_stack
|
468 |
+
else:
|
469 |
+
for state in possible_states:
|
470 |
+
cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
|
471 |
+
return cos_sims_vs_alt_dict
|
472 |
+
|
473 |
+
# calculate cos sim shift of perturbation with respect to origin and alternative cell
|
474 |
+
def cos_sim_shift(original_emb,
|
475 |
+
minibatch_emb,
|
476 |
+
end_emb,
|
477 |
+
perturb_group,
|
478 |
+
original_minibatch_lengths = None,
|
479 |
+
minibatch_lengths = None):
|
480 |
+
cos = torch.nn.CosineSimilarity(dim=2)
|
481 |
+
if not perturb_group:
|
482 |
+
original_emb = torch.mean(original_emb,dim=0,keepdim=True)
|
483 |
+
original_emb = original_emb[None, :]
|
484 |
+
origin_v_end = torch.squeeze(cos(original_emb, end_emb)) #test
|
485 |
+
else:
|
486 |
+
if original_emb.size() != minibatch_emb.size():
|
487 |
+
logger.error(
|
488 |
+
f"Embeddings are not the same dimensions. " \
|
489 |
+
f"original_emb is {original_emb.size()}. " \
|
490 |
+
f"minibatch_emb is {minibatch_emb.size()}. "
|
491 |
+
)
|
492 |
+
raise
|
493 |
+
|
494 |
+
if original_minibatch_lengths is not None:
|
495 |
+
original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
|
496 |
+
# else:
|
497 |
+
# original_emb = torch.mean(original_emb,dim=1,keepdim=True)
|
498 |
+
|
499 |
+
end_emb = torch.unsqueeze(end_emb, 1)
|
500 |
+
origin_v_end = cos(original_emb, end_emb)
|
501 |
+
origin_v_end = torch.squeeze(origin_v_end)
|
502 |
+
if minibatch_lengths is not None:
|
503 |
+
perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
|
504 |
+
else:
|
505 |
+
perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
|
506 |
+
|
507 |
+
perturb_v_end = cos(perturb_emb, end_emb)
|
508 |
+
perturb_v_end = torch.squeeze(perturb_v_end)
|
509 |
+
return [(perturb_v_end-origin_v_end).to("cpu")]
|
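A tiny CPU-only illustration of the shift metric returned above; the vectors are made up, and a positive shift means the perturbation moved the embedding toward the end state:

import torch

cos = torch.nn.CosineSimilarity(dim=0)
original  = torch.tensor([1.0, 0.0])
perturbed = torch.tensor([0.6, 0.8])
goal      = torch.tensor([0.0, 1.0])
shift = cos(perturbed, goal) - cos(original, goal)   # 0.8 - 0.0 = 0.8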
510 |
+
|
511 |
+
def pad_list(input_ids, pad_token_id, max_len):
|
512 |
+
input_ids = np.pad(input_ids,
|
513 |
+
(0, max_len-len(input_ids)),
|
514 |
+
mode='constant', constant_values=pad_token_id)
|
515 |
+
return input_ids
|
516 |
+
|
517 |
+
def pad_tensor(tensor, pad_token_id, max_len):
|
518 |
+
tensor = torch.nn.functional.pad(tensor, pad=(0,
|
519 |
+
max_len - tensor.numel()),
|
520 |
+
mode='constant',
|
521 |
+
value=pad_token_id)
|
522 |
+
return tensor
|
523 |
+
|
524 |
+
def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
|
525 |
+
if dim == 0:
|
526 |
+
pad = (0, 0, 0, max_len - tensor.size()[dim])
|
527 |
+
elif dim == 1:
|
528 |
+
pad = (0, max_len - tensor.size()[dim], 0, 0)
|
529 |
+
tensor = torch.nn.functional.pad(tensor, pad=pad,
|
530 |
+
mode='constant',
|
531 |
+
value=pad_token_id)
|
532 |
+
return tensor
|
533 |
+
|
534 |
+
def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
|
535 |
+
if isinstance(encoding, torch.Tensor):
|
536 |
+
encoding_len = encoding.size()[0]
|
537 |
+
elif isinstance(encoding, list):
|
538 |
+
encoding_len = len(encoding)
|
539 |
+
if encoding_len > max_len:
|
540 |
+
encoding = encoding[0:max_len]
|
541 |
+
elif encoding_len < max_len:
|
542 |
+
if isinstance(encoding, torch.Tensor):
|
543 |
+
encoding = pad_tensor(encoding, pad_token_id, max_len)
|
544 |
+
elif isinstance(encoding, list):
|
545 |
+
encoding = pad_list(encoding, pad_token_id, max_len)
|
546 |
+
return encoding
|
547 |
+
|
548 |
+
# pad list of tensors and convert to tensor
|
549 |
+
def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_size):
|
550 |
+
|
551 |
+
# Determine maximum tensor length
|
552 |
+
if dynamic_or_constant == "dynamic":
|
553 |
+
max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
|
554 |
+
elif type(dynamic_or_constant) == int:
|
555 |
+
max_len = dynamic_or_constant
|
556 |
+
else:
|
557 |
+
max_len = model_input_size
|
558 |
+
logger.warning(
|
559 |
+
"If padding style is constant, must provide integer value. " \
|
560 |
+
f"Setting padding to max input size {model_input_size}.")
|
561 |
+
|
562 |
+
# pad all tensors to maximum length
|
563 |
+
tensor_list = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
|
564 |
+
|
565 |
+
# return stacked tensors
|
566 |
+
return torch.stack(tensor_list)
|
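A minimal sketch of the "dynamic" padding path, assuming a pad token id of 0 (the real pad id comes from the token dictionary):

import torch

tensors = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_len = max(t.numel() for t in tensors)   # dynamic: pad to the longest tensor, here 3
padded = torch.stack([torch.nn.functional.pad(t, (0, max_len - t.numel()), value=0)
                      for t in tensors])
# padded -> tensor([[5, 6, 7], [8, 9, 0]])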
567 |
+
|
568 |
+
def gen_attention_mask(minibatch_encoding, max_len = None):
|
569 |
+
if max_len == None:
|
570 |
+
max_len = max(minibatch_encoding["length"])
|
571 |
+
original_lens = minibatch_encoding["length"]
|
572 |
+
attention_mask = [[1]*original_len
|
573 |
+
+[0]*(max_len - original_len)
|
574 |
+
if original_len <= max_len
|
575 |
+
else [1]*max_len
|
576 |
+
for original_len in original_lens]
|
577 |
+
return torch.tensor(attention_mask).to("cuda")
|
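The same masking logic on the CPU, with made-up lengths (the function above additionally moves the result to CUDA):

original_lens = [2, 4]
max_len = 4
attention_mask = [[1] * l + [0] * (max_len - l) if l <= max_len else [1] * max_len
                  for l in original_lens]
# attention_mask -> [[1, 1, 0, 0], [1, 1, 1, 1]]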
578 |
+
|
579 |
+
# get cell embeddings excluding padding
|
580 |
+
def mean_nonpadding_embs(embs, original_lens):
|
581 |
+
# mask based on padding lengths
|
582 |
+
mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
|
583 |
+
|
584 |
+
# extend mask dimensions to match the embeddings tensor
|
585 |
+
mask = mask.unsqueeze(2).expand_as(embs)
|
586 |
+
|
587 |
+
# use the mask to zero out the embeddings in padded areas
|
588 |
+
masked_embs = embs * mask.float()
|
589 |
+
|
590 |
+
# sum and divide by the lengths to get the mean of non-padding embs
|
591 |
+
mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
|
592 |
+
return mean_embs
|
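A CPU sketch of the non-padding mean pooling, with a made-up (batch, seq, hidden) tensor; the function above builds the same mask on CUDA:

import torch

embs = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)   # (batch, seq, hidden)
original_lens = torch.tensor([2, 3])
mask = torch.arange(embs.size(1)).unsqueeze(0) < original_lens.unsqueeze(1)
mask = mask.unsqueeze(2).expand_as(embs)
# zero out padded positions, then divide each cell's sum by its true length
mean_embs = (embs * mask.float()).sum(1) / original_lens.view(-1, 1).float()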
593 |
+
|
594 |
+
class InSilicoPerturber:
|
595 |
+
valid_option_dict = {
|
596 |
+
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
597 |
+
"perturb_rank_shift": {None, 1, 2, 3},
|
598 |
+
"genes_to_perturb": {"all", list},
|
599 |
+
"combos": {0, 1},
|
600 |
+
"anchor_gene": {None, str},
|
601 |
+
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
602 |
+
"num_classes": {int},
|
603 |
+
"emb_mode": {"cell","cell_and_gene"},
|
604 |
+
"cell_emb_style": {"mean_pool"},
|
605 |
+
"filter_data": {None, dict},
|
606 |
+
"cell_states_to_model": {None, dict},
|
607 |
+
"max_ncells": {None, int},
|
608 |
+
"cell_inds_to_perturb": {"all", dict},
|
609 |
+
"emb_layer": {-1, 0},
|
610 |
+
"forward_batch_size": {int},
|
611 |
+
"nproc": {int},
|
612 |
+
}
|
613 |
+
def __init__(
|
614 |
+
self,
|
615 |
+
perturb_type="delete",
|
616 |
+
perturb_rank_shift=None,
|
617 |
+
genes_to_perturb="all",
|
618 |
+
combos=0,
|
619 |
+
anchor_gene=None,
|
620 |
+
model_type="Pretrained",
|
621 |
+
num_classes=0,
|
622 |
+
emb_mode="cell",
|
623 |
+
cell_emb_style="mean_pool",
|
624 |
+
filter_data=None,
|
625 |
+
cell_states_to_model=None,
|
626 |
+
max_ncells=None,
|
627 |
+
cell_inds_to_perturb="all",
|
628 |
+
emb_layer=-1,
|
629 |
+
forward_batch_size=100,
|
630 |
+
nproc=4,
|
631 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
632 |
+
):
|
633 |
+
"""
|
634 |
+
Initialize in silico perturber.
|
635 |
+
|
636 |
+
Parameters
|
637 |
+
----------
|
638 |
+
perturb_type : {"delete","overexpress","inhibit","activate"}
|
639 |
+
Type of perturbation.
|
640 |
+
"delete": delete gene from rank value encoding
|
641 |
+
"overexpress": move gene to front of rank value encoding
|
642 |
+
"inhibit": move gene to lower quartile of rank value encoding
|
643 |
+
"activate": move gene to higher quartile of rank value encoding
|
644 |
+
perturb_rank_shift : None, {1,2,3}
|
645 |
+
Number of quartiles by which to shift rank of gene.
|
646 |
+
For example, if perturb_type="activate" and perturb_rank_shift=1:
|
647 |
+
genes in 4th quartile will move to middle of 3rd quartile.
|
648 |
+
genes in 3rd quartile will move to middle of 2nd quartile.
|
649 |
+
genes in 2nd quartile will move to middle of 1st quartile.
|
650 |
+
genes in 1st quartile will move to front of rank value encoding.
|
651 |
+
For example, if perturb_type="inhibit" and perturb_rank_shift=2:
|
652 |
+
genes in 1st quartile will move to middle of 3rd quartile.
|
653 |
+
genes in 2nd quartile will move to middle of 4th quartile.
|
654 |
+
genes in 3rd or 4th quartile will move to bottom of rank value encoding.
|
655 |
+
genes_to_perturb : "all", list
|
656 |
+
Default is perturbing each gene detected in each cell in the dataset.
|
657 |
+
Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
|
658 |
+
If gene list is provided, then perturber will only test perturbing them all together
|
659 |
+
(rather than testing each possible combination of the provided genes).
|
660 |
+
combos : {0,1}
|
661 |
+
Whether to perturb genes individually (0) or in pairs (1).
|
662 |
+
anchor_gene : None, str
|
663 |
+
ENSEMBL ID of gene to use as anchor in combination perturbations.
|
664 |
+
For example, if combos=1 and anchor_gene="ENSG00000148400":
|
665 |
+
anchor gene will be perturbed in combination with each other gene.
|
666 |
+
model_type : {"Pretrained","GeneClassifier","CellClassifier"}
|
667 |
+
Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
668 |
+
num_classes : int
|
669 |
+
If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
670 |
+
For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
671 |
+
emb_mode : {"cell","cell_and_gene"}
|
672 |
+
Whether to output impact of perturbation on cell and/or gene embeddings.
|
673 |
+
cell_emb_style : "mean_pool"
|
674 |
+
Method for summarizing cell embeddings.
|
675 |
+
Currently only option is mean pooling of gene embeddings for given cell.
|
676 |
+
filter_data : None, dict
|
677 |
+
Default is to use all input data for in silico perturbation study.
|
678 |
+
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
679 |
+
cell_states_to_model: None, dict
|
680 |
+
Cell states to model if testing perturbations that achieve goal state change.
|
681 |
+
Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
682 |
+
state_key: key specifying name of column in .dataset that defines the start/goal states
|
683 |
+
start_state: value in the state_key column that specifies the start state
|
684 |
+
goal_state: value in the state_key column that specifies the goal end state
|
685 |
+
alt_states: list of values in the state_key column that specify the alternate end states
|
686 |
+
For example: {"state_key": "disease",
|
687 |
+
"start_state": "dcm",
|
688 |
+
"goal_state": "nf",
|
689 |
+
"alt_states": ["hcm", "other1", "other2"]}
|
690 |
+
max_ncells : None, int
|
691 |
+
Maximum number of cells to test.
|
692 |
+
If None, will test all cells.
|
693 |
+
cell_inds_to_perturb : "all", list
|
694 |
+
Default is perturbing each cell in the dataset.
|
695 |
+
Otherwise, may provide a dict of indices of cells to perturb with keys "start" and "end".
|
696 |
+
"start": the first index to perturb.
|
697 |
+
"end": the last index to perturb (exclusive).
|
698 |
+
Indices will be selected *after* the filter_data criteria and sorting.
|
699 |
+
Useful for splitting extremely large datasets across separate GPUs.
|
700 |
+
emb_layer : {-1, 0}
|
701 |
+
Embedding layer to use for quantification.
|
702 |
+
-1: 2nd to last layer (recommended for pretrained Geneformer)
|
703 |
+
0: last layer (recommended for cell classifier fine-tuned for disease state)
|
704 |
+
forward_batch_size : int
|
705 |
+
Batch size for forward pass.
|
706 |
+
nproc : int
|
707 |
+
Number of CPU processes to use.
|
708 |
+
token_dictionary_file : Path
|
709 |
+
Path to pickle file containing token dictionary (Ensembl ID:token).
|
710 |
+
"""
|
711 |
+
|
712 |
+
self.perturb_type = perturb_type
|
713 |
+
self.perturb_rank_shift = perturb_rank_shift
|
714 |
+
self.genes_to_perturb = genes_to_perturb
|
715 |
+
self.combos = combos
|
716 |
+
self.anchor_gene = anchor_gene
|
717 |
+
if self.genes_to_perturb == "all":
|
718 |
+
self.perturb_group = False
|
719 |
+
else:
|
720 |
+
self.perturb_group = True
|
721 |
+
if (self.anchor_gene != None) or (self.combos != 0):
|
722 |
+
self.anchor_gene = None
|
723 |
+
self.combos = 0
|
724 |
+
logger.warning(
|
725 |
+
"anchor_gene set to None and combos set to 0. " \
|
726 |
+
"If providing list of genes to perturb, " \
|
727 |
+
"list of genes_to_perturb will be perturbed together, "\
|
728 |
+
"without anchor gene or combinations.")
|
729 |
+
self.model_type = model_type
|
730 |
+
self.num_classes = num_classes
|
731 |
+
self.emb_mode = emb_mode
|
732 |
+
self.cell_emb_style = cell_emb_style
|
733 |
+
self.filter_data = filter_data
|
734 |
+
self.cell_states_to_model = cell_states_to_model
|
735 |
+
self.max_ncells = max_ncells
|
736 |
+
self.cell_inds_to_perturb = cell_inds_to_perturb
|
737 |
+
self.emb_layer = emb_layer
|
738 |
+
self.forward_batch_size = forward_batch_size
|
739 |
+
self.nproc = nproc
|
740 |
+
|
741 |
+
self.validate_options()
|
742 |
+
|
743 |
+
# load token dictionary (Ensembl IDs:token)
|
744 |
+
with open(token_dictionary_file, "rb") as f:
|
745 |
+
self.gene_token_dict = pickle.load(f)
|
746 |
+
|
747 |
+
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
748 |
+
|
749 |
+
if self.anchor_gene is None:
|
750 |
+
self.anchor_token = None
|
751 |
+
else:
|
752 |
+
try:
|
753 |
+
self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
|
754 |
+
except KeyError:
|
755 |
+
logger.error(
|
756 |
+
f"Anchor gene {self.anchor_gene} not in token dictionary."
|
757 |
+
)
|
758 |
+
raise
|
759 |
+
|
760 |
+
if self.genes_to_perturb == "all":
|
761 |
+
self.tokens_to_perturb = "all"
|
762 |
+
else:
|
763 |
+
missing_genes = [gene for gene in self.genes_to_perturb if gene not in self.gene_token_dict.keys()]
|
764 |
+
if len(missing_genes) == len(self.genes_to_perturb):
|
765 |
+
logger.error(
|
766 |
+
"None of the provided genes to perturb are in token dictionary."
|
767 |
+
)
|
768 |
+
raise
|
769 |
+
elif len(missing_genes)>0:
|
770 |
+
logger.warning(
|
771 |
+
f"Genes to perturb {missing_genes} are not in token dictionary.")
|
772 |
+
self.tokens_to_perturb = [self.gene_token_dict.get(gene) for gene in self.genes_to_perturb]
|
773 |
+
|
774 |
+
def validate_options(self):
|
775 |
+
# first disallow options under development
|
776 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
777 |
+
logger.error(
|
778 |
+
"In silico inhibition and activation currently under development. " \
|
779 |
+
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
780 |
+
)
|
781 |
+
raise
|
782 |
+
|
783 |
+
# confirm arguments are within valid options and compatible with each other
|
784 |
+
for attr_name,valid_options in self.valid_option_dict.items():
|
785 |
+
attr_value = self.__dict__[attr_name]
|
786 |
+
if type(attr_value) not in {list, dict}:
|
787 |
+
if attr_value in valid_options:
|
788 |
+
continue
|
789 |
+
if attr_name in ["anchor_gene"]:
|
790 |
+
if type(attr_name) in {str}:
|
791 |
+
continue
|
792 |
+
valid_type = False
|
793 |
+
for option in valid_options:
|
794 |
+
if (option in [int,list,dict]) and isinstance(attr_value, option):
|
795 |
+
valid_type = True
|
796 |
+
break
|
797 |
+
if valid_type:
|
798 |
+
continue
|
799 |
+
logger.error(
|
800 |
+
f"Invalid option for {attr_name}. " \
|
801 |
+
f"Valid options for {attr_name}: {valid_options}"
|
802 |
+
)
|
803 |
+
raise
|
804 |
+
|
805 |
+
if self.perturb_type in ["delete","overexpress"]:
|
806 |
+
if self.perturb_rank_shift is not None:
|
807 |
+
if self.perturb_type == "delete":
|
808 |
+
logger.warning(
|
809 |
+
"perturb_rank_shift set to None. " \
|
810 |
+
"If perturb type is delete then gene is deleted entirely " \
|
811 |
+
"rather than shifted by quartile")
|
812 |
+
elif self.perturb_type == "overexpress":
|
813 |
+
logger.warning(
|
814 |
+
"perturb_rank_shift set to None. " \
|
815 |
+
"If perturb type is overexpress then gene is moved to front " \
|
816 |
+
"of rank value encoding rather than shifted by quartile")
|
817 |
+
self.perturb_rank_shift = None
|
818 |
+
|
819 |
+
if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
|
820 |
+
self.emb_mode = "cell"
|
821 |
+
logger.warning(
|
822 |
+
"emb_mode set to 'cell'. " \
|
823 |
+
"Currently, analysis with anchor gene " \
|
824 |
+
"only outputs effect on cell embeddings.")
|
825 |
+
|
826 |
+
if self.cell_states_to_model is not None:
|
827 |
+
if len(self.cell_states_to_model.items()) == 1:
|
828 |
+
logger.warning(
|
829 |
+
"The single value dictionary for cell_states_to_model will be " \
|
830 |
+
"replaced with a dictionary with named keys for start, goal, and alternate states. " \
|
831 |
+
"Please specify state_key, start_state, goal_state, and alt_states " \
|
832 |
+
"in the cell_states_to_model dictionary for future use. " \
|
833 |
+
"For example, cell_states_to_model={" \
|
834 |
+
"'state_key': 'disease', " \
|
835 |
+
"'start_state': 'dcm', " \
|
836 |
+
"'goal_state': 'nf', " \
|
837 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
838 |
+
)
|
839 |
+
for key,value in self.cell_states_to_model.items():
|
840 |
+
if (len(value) == 3) and isinstance(value, tuple):
|
841 |
+
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
842 |
+
if len(value[0]) == 1 and len(value[1]) == 1:
|
843 |
+
all_values = value[0]+value[1]+value[2]
|
844 |
+
if len(all_values) == len(set(all_values)):
|
845 |
+
continue
|
846 |
+
# reformat to the new named key format
|
847 |
+
state_values = flatten_list(list(self.cell_states_to_model.values()))
|
848 |
+
self.cell_states_to_model = {
|
849 |
+
"state_key": list(self.cell_states_to_model.keys())[0],
|
850 |
+
"start_state": state_values[0][0],
|
851 |
+
"goal_state": state_values[1][0],
|
852 |
+
"alt_states": state_values[2:][0]
|
853 |
+
}
|
854 |
+
elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
|
855 |
+
if (self.cell_states_to_model["state_key"] is None) \
|
856 |
+
or (self.cell_states_to_model["start_state"] is None) \
|
857 |
+
or (self.cell_states_to_model["goal_state"] is None):
|
858 |
+
logger.error(
|
859 |
+
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
|
860 |
+
raise
|
861 |
+
|
862 |
+
if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
|
863 |
+
logger.error(
|
864 |
+
"All states must be unique.")
|
865 |
+
raise
|
866 |
+
|
867 |
+
if self.cell_states_to_model["alt_states"] is not None:
|
868 |
+
if type(self.cell_states_to_model["alt_states"]) is not list:
|
869 |
+
logger.error(
|
870 |
+
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
871 |
+
)
|
872 |
+
raise
|
873 |
+
if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
|
874 |
+
logger.error(
|
875 |
+
"All states must be unique.")
|
876 |
+
raise
|
877 |
+
|
878 |
+
else:
|
879 |
+
logger.error(
|
880 |
+
"cell_states_to_model must only have the following four keys: " \
|
881 |
+
"'state_key', 'start_state', 'goal_state', 'alt_states'." \
|
882 |
+
"For example, cell_states_to_model={" \
|
883 |
+
"'state_key': 'disease', " \
|
884 |
+
"'start_state': 'dcm', " \
|
885 |
+
"'goal_state': 'nf', " \
|
886 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
887 |
+
)
|
888 |
+
raise
|
889 |
+
|
890 |
+
if self.anchor_gene is not None:
|
891 |
+
self.anchor_gene = None
|
892 |
+
logger.warning(
|
893 |
+
"anchor_gene set to None. " \
|
894 |
+
"Currently, anchor gene not available " \
|
895 |
+
"when modeling multiple cell states.")
|
896 |
+
|
897 |
+
if self.perturb_type in ["inhibit","activate"]:
|
898 |
+
if self.perturb_rank_shift is None:
|
899 |
+
logger.error(
|
900 |
+
"If perturb_type is inhibit or activate then " \
|
901 |
+
"quartile to shift by must be specified.")
|
902 |
+
raise
|
903 |
+
|
904 |
+
if self.filter_data is not None:
|
905 |
+
for key,value in self.filter_data.items():
|
906 |
+
if type(value) != list:
|
907 |
+
self.filter_data[key] = [value]
|
908 |
+
logger.warning(
|
909 |
+
"Values in filter_data dict must be lists. " \
|
910 |
+
f"Changing {key} value to list ([{value}]).")
|
911 |
+
|
912 |
+
if self.cell_inds_to_perturb != "all":
|
913 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
914 |
+
logger.error(
|
915 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
916 |
+
)
|
917 |
+
raise
|
918 |
+
if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
|
919 |
+
logger.error(
|
920 |
+
'cell_inds_to_perturb values must be non-negative.'
|
921 |
+
)
|
922 |
+
raise
|
923 |
+
|
924 |
+
def perturb_data(self,
|
925 |
+
model_directory,
|
926 |
+
input_data_file,
|
927 |
+
output_directory,
|
928 |
+
output_prefix):
|
929 |
+
"""
|
930 |
+
Perturb genes in input data and save as results in output_directory.
|
931 |
+
|
932 |
+
Parameters
|
933 |
+
----------
|
934 |
+
model_directory : Path
|
935 |
+
Path to directory containing model
|
936 |
+
input_data_file : Path
|
937 |
+
Path to directory containing .dataset inputs
|
938 |
+
output_directory : Path
|
939 |
+
Path to directory where perturbation data will be saved as batched pickle files
|
940 |
+
output_prefix : str
|
941 |
+
Prefix for output files
|
942 |
+
"""
|
943 |
+
|
944 |
+
filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
|
945 |
+
model = load_model(self.model_type, self.num_classes, model_directory)
|
946 |
+
layer_to_quant = quant_layers(model)+self.emb_layer
|
947 |
+
|
948 |
+
if self.cell_states_to_model is None:
|
949 |
+
state_embs_dict = None
|
950 |
+
else:
|
951 |
+
# confirm that all states are valid to prevent futile filtering
|
952 |
+
state_name = self.cell_states_to_model["state_key"]
|
953 |
+
state_values = filtered_input_data[state_name]
|
954 |
+
for value in get_possible_states(self.cell_states_to_model):
|
955 |
+
if value not in state_values:
|
956 |
+
logger.error(
|
957 |
+
f"{value} is not present in the dataset's {state_name} attribute.")
|
958 |
+
raise
|
959 |
+
# get dictionary of average cell state embeddings for comparison
|
960 |
+
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
961 |
+
state_embs_dict = get_cell_state_avg_embs(model,
|
962 |
+
downsampled_data,
|
963 |
+
self.cell_states_to_model,
|
964 |
+
layer_to_quant,
|
965 |
+
self.pad_token_id,
|
966 |
+
self.forward_batch_size,
|
967 |
+
self.nproc)
|
968 |
+
# filter for start state cells
|
969 |
+
start_state = self.cell_states_to_model["start_state"]
|
970 |
+
def filter_for_origin(example):
|
971 |
+
return example[state_name] in [start_state]
|
972 |
+
|
973 |
+
filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
|
974 |
+
|
975 |
+
self.in_silico_perturb(model,
|
976 |
+
filtered_input_data,
|
977 |
+
layer_to_quant,
|
978 |
+
state_embs_dict,
|
979 |
+
output_directory,
|
980 |
+
output_prefix)
|
981 |
+
|
982 |
+
# determine effect of perturbation on other genes
|
983 |
+
def in_silico_perturb(self,
|
984 |
+
model,
|
985 |
+
filtered_input_data,
|
986 |
+
layer_to_quant,
|
987 |
+
state_embs_dict,
|
988 |
+
output_directory,
|
989 |
+
output_prefix):
|
990 |
+
|
991 |
+
output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
|
992 |
+
model_input_size = get_model_input_size(model)
|
993 |
+
|
994 |
+
# filter dataset for cells that have tokens to be perturbed
|
995 |
+
if self.anchor_token is not None:
|
996 |
+
def if_has_tokens_to_perturb(example):
|
997 |
+
return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
|
998 |
+
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
999 |
+
if len(filtered_input_data) == 0:
|
1000 |
+
logger.error(
|
1001 |
+
"No cells in dataset contain anchor gene.")
|
1002 |
+
raise
|
1003 |
+
else:
|
1004 |
+
logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
|
1005 |
+
|
1006 |
+
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
1007 |
+
# minimum # genes needed for perturbation test
|
1008 |
+
min_genes = len(self.tokens_to_perturb)
|
1009 |
+
|
1010 |
+
def if_has_tokens_to_perturb(example):
|
1011 |
+
return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>=min_genes)
|
1012 |
+
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
1013 |
+
if len(filtered_input_data) == 0:
|
1014 |
+
logger.error(
|
1015 |
+
"No cells in dataset contain all genes to perturb as a group.")
|
1016 |
+
raise
|
1017 |
+
|
1018 |
+
cos_sims_dict = defaultdict(list)
|
1019 |
+
pickle_batch = -1
|
1020 |
+
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
1021 |
+
if self.cell_inds_to_perturb != "all":
|
1022 |
+
if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
|
1023 |
+
logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
|
1024 |
+
raise
|
1025 |
+
if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
|
1026 |
+
logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
|
1027 |
+
Setting to the end of the filtered dataset.")
|
1028 |
+
self.cell_inds_to_perturb["end"] = len(filtered_input_data)
|
1029 |
+
filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
|
1030 |
+
|
1031 |
+
# make perturbation batch w/ single perturbation in multiple cells
|
1032 |
+
if self.perturb_group == True:
|
1033 |
+
|
1034 |
+
def make_group_perturbation_batch(example):
|
1035 |
+
example_input_ids = example["input_ids"]
|
1036 |
+
example["tokens_to_perturb"] = self.tokens_to_perturb
|
1037 |
+
indices_to_perturb = [example_input_ids.index(token) if token in example_input_ids else None for token in self.tokens_to_perturb]
|
1038 |
+
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
1039 |
+
if len(indices_to_perturb) > 0:
|
1040 |
+
example["perturb_index"] = indices_to_perturb
|
1041 |
+
else:
|
1042 |
+
# -100 indicates tokens to overexpress are not present in rank value encoding
|
1043 |
+
example["perturb_index"] = [-100]
|
1044 |
+
if self.perturb_type == "delete":
|
1045 |
+
example = delete_indices(example)
|
1046 |
+
elif self.perturb_type == "overexpress":
|
1047 |
+
example = overexpress_tokens(example)
|
1048 |
+
return example
|
1049 |
+
|
1050 |
+
perturbation_batch = filtered_input_data.map(make_group_perturbation_batch, num_proc=self.nproc)
|
1051 |
+
indices_to_perturb = perturbation_batch["perturb_index"]
|
1052 |
+
|
1053 |
+
cos_sims_data = quant_cos_sims(model,
|
1054 |
+
self.perturb_type,
|
1055 |
+
perturbation_batch,
|
1056 |
+
self.forward_batch_size,
|
1057 |
+
layer_to_quant,
|
1058 |
+
filtered_input_data,
|
1059 |
+
self.tokens_to_perturb,
|
1060 |
+
indices_to_perturb,
|
1061 |
+
self.perturb_group,
|
1062 |
+
self.cell_states_to_model,
|
1063 |
+
state_embs_dict,
|
1064 |
+
self.pad_token_id,
|
1065 |
+
model_input_size,
|
1066 |
+
self.nproc)
|
1067 |
+
|
1068 |
+
perturbed_genes = tuple(self.tokens_to_perturb)
|
1069 |
+
original_lengths = filtered_input_data["length"]
|
1070 |
+
if self.cell_states_to_model is None:
|
1071 |
+
# update cos sims dict
|
1072 |
+
# key is tuple of (perturbed_gene, affected_gene)
|
1073 |
+
# or (perturbed_genes, "cell_emb") for avg cell emb change
|
1074 |
+
cos_sims_data = cos_sims_data.to("cuda")
|
1075 |
+
max_padded_len = cos_sims_data.shape[1]
|
1076 |
+
for j in range(cos_sims_data.shape[0]):
|
1077 |
+
# remove padding before mean pooling cell embedding
|
1078 |
+
original_length = original_lengths[j]
|
1079 |
+
gene_list = filtered_input_data[j]["input_ids"]
|
1080 |
+
indices_removed = indices_to_perturb[j]
|
1081 |
+
padding_to_remove = max_padded_len - (original_length \
|
1082 |
+
- len(self.tokens_to_perturb) \
|
1083 |
+
- len(indices_removed))
|
1084 |
+
nonpadding_cos_sims_data = cos_sims_data[j][:-padding_to_remove]
|
1085 |
+
cell_cos_sim = torch.mean(nonpadding_cos_sims_data).item()
|
1086 |
+
cos_sims_dict[(perturbed_genes, "cell_emb")] += [cell_cos_sim]
|
1087 |
+
|
1088 |
+
if self.emb_mode == "cell_and_gene":
|
1089 |
+
for k in range(cos_sims_data.shape[1]):
|
1090 |
+
cos_sim_value = nonpadding_cos_sims_data[k]
|
1091 |
+
affected_gene = gene_list[k].item()
|
1092 |
+
cos_sims_dict[(perturbed_genes, affected_gene)] += [cos_sim_value.item()]
|
1093 |
+
else:
|
1094 |
+
# update cos sims dict
|
1095 |
+
# key is tuple of (perturbed_genes, "cell_emb")
|
1096 |
+
# value is list of tuples of cos sims for cell_states_to_model
|
1097 |
+
origin_state_key = self.cell_states_to_model["start_state"]
|
1098 |
+
cos_sims_origin = cos_sims_data[origin_state_key]
|
1099 |
+
for j in range(cos_sims_origin.shape[0]):
|
1100 |
+
data_list = []
|
1101 |
+
for data in list(cos_sims_data.values()):
|
1102 |
+
data_item = data.to("cuda")
|
1103 |
+
data_list += [data_item[j].item()]
|
1104 |
+
cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
|
1105 |
+
|
1106 |
+
with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
|
1107 |
+
pickle.dump(cos_sims_dict, fp)
|
1108 |
+
|
1109 |
+
# make perturbation batch w/ multiple perturbations in single cell
|
1110 |
+
if self.perturb_group == False:
|
1111 |
+
|
1112 |
+
for i in trange(len(filtered_input_data)):
|
1113 |
+
example_cell = filtered_input_data.select([i])
|
1114 |
+
original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
|
1115 |
+
gene_list = torch.squeeze(example_cell["input_ids"])
|
1116 |
+
|
1117 |
+
# reset to original type to prevent downstream issues from forward_pass_single_cell having modified the example to torch format in place
|
1118 |
+
example_cell = filtered_input_data.select([i])
|
1119 |
+
|
1120 |
+
if self.anchor_token is None:
|
1121 |
+
for combo_lvl in range(self.combos+1):
|
1122 |
+
perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
|
1123 |
+
self.perturb_type,
|
1124 |
+
self.tokens_to_perturb,
|
1125 |
+
self.anchor_token,
|
1126 |
+
combo_lvl,
|
1127 |
+
self.nproc)
|
1128 |
+
cos_sims_data = quant_cos_sims(model,
|
1129 |
+
self.perturb_type,
|
1130 |
+
perturbation_batch,
|
1131 |
+
self.forward_batch_size,
|
1132 |
+
layer_to_quant,
|
1133 |
+
original_emb,
|
1134 |
+
self.tokens_to_perturb,
|
1135 |
+
indices_to_perturb,
|
1136 |
+
self.perturb_group,
|
1137 |
+
self.cell_states_to_model,
|
1138 |
+
state_embs_dict,
|
1139 |
+
self.pad_token_id,
|
1140 |
+
model_input_size,
|
1141 |
+
self.nproc)
|
1142 |
+
|
1143 |
+
if self.cell_states_to_model is None:
|
1144 |
+
# update cos sims dict
|
1145 |
+
# key is tuple of (perturbed_gene, affected_gene)
|
1146 |
+
# or (perturbed_gene, "cell_emb") for avg cell emb change
|
1147 |
+
cos_sims_data = cos_sims_data.to("cuda")
|
1148 |
+
for j in range(cos_sims_data.shape[0]):
|
1149 |
+
if self.tokens_to_perturb != "all":
|
1150 |
+
j_index = torch.tensor(indices_to_perturb[j])
|
1151 |
+
if j_index.shape[0]>1:
|
1152 |
+
j_index = torch.squeeze(j_index)
|
1153 |
+
else:
|
1154 |
+
j_index = torch.tensor([j])
|
1155 |
+
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1156 |
+
|
1157 |
+
if perturbed_gene.shape[0]==1:
|
1158 |
+
perturbed_gene = perturbed_gene.item()
|
1159 |
+
elif perturbed_gene.shape[0]>1:
|
1160 |
+
perturbed_gene = tuple(perturbed_gene.tolist())
|
1161 |
+
|
1162 |
+
cell_cos_sim = torch.mean(cos_sims_data[j]).item()
|
1163 |
+
cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
|
1164 |
+
|
1165 |
+
# not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
|
1166 |
+
# gene_list_j = torch.index_select(gene_list, 0, j_index)
|
1167 |
+
if self.emb_mode == "cell_and_gene":
|
1168 |
+
for k in range(cos_sims_data.shape[1]):
|
1169 |
+
cos_sim_value = cos_sims_data[j][k]
|
1170 |
+
affected_gene = gene_list[k].item()
|
1171 |
+
cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
|
1172 |
+
else:
|
1173 |
+
# update cos sims dict
|
1174 |
+
# key is tuple of (perturbed_gene, "cell_emb")
|
1175 |
+
# value is list of tuples of cos sims for cell_states_to_model
|
1176 |
+
origin_state_key = self.cell_states_to_model["start_state"]
|
1177 |
+
cos_sims_origin = cos_sims_data[origin_state_key]
|
1178 |
+
|
1179 |
+
for j in range(cos_sims_origin.shape[0]):
|
1180 |
+
if (self.tokens_to_perturb != "all") or (combo_lvl>0):
|
1181 |
+
j_index = torch.tensor(indices_to_perturb[j])
|
1182 |
+
if j_index.shape[0]>1:
|
1183 |
+
j_index = torch.squeeze(j_index)
|
1184 |
+
else:
|
1185 |
+
j_index = torch.tensor([j])
|
1186 |
+
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1187 |
+
|
1188 |
+
if perturbed_gene.shape[0]==1:
|
1189 |
+
perturbed_gene = perturbed_gene.item()
|
1190 |
+
elif perturbed_gene.shape[0]>1:
|
1191 |
+
perturbed_gene = tuple(perturbed_gene.tolist())
|
1192 |
+
|
1193 |
+
data_list = []
|
1194 |
+
for data in list(cos_sims_data.values()):
|
1195 |
+
data_item = data.to("cuda")
|
1196 |
+
cell_data = torch.mean(data_item[j]).item()
|
1197 |
+
data_list += [cell_data]
|
1198 |
+
cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
|
1199 |
+
|
1200 |
+
elif self.anchor_token is not None:
|
1201 |
+
perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
|
1202 |
+
self.perturb_type,
|
1203 |
+
self.tokens_to_perturb,
|
1204 |
+
None, # first run without anchor token to test individual gene perturbations
|
1205 |
+
0,
|
1206 |
+
self.nproc)
|
1207 |
+
cos_sims_data = quant_cos_sims(model,
|
1208 |
+
self.perturb_type,
|
1209 |
+
perturbation_batch,
|
1210 |
+
self.forward_batch_size,
|
1211 |
+
layer_to_quant,
|
1212 |
+
original_emb,
|
1213 |
+
self.tokens_to_perturb,
|
1214 |
+
indices_to_perturb,
|
1215 |
+
self.perturb_group,
|
1216 |
+
self.cell_states_to_model,
|
1217 |
+
state_embs_dict,
|
1218 |
+
self.pad_token_id,
|
1219 |
+
model_input_size,
|
1220 |
+
self.nproc)
|
1221 |
+
cos_sims_data = cos_sims_data.to("cuda")
|
1222 |
+
|
1223 |
+
combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
|
1224 |
+
self.perturb_type,
|
1225 |
+
self.tokens_to_perturb,
|
1226 |
+
self.anchor_token,
|
1227 |
+
1,
|
1228 |
+
self.nproc)
|
1229 |
+
combo_cos_sims_data = quant_cos_sims(model,
|
1230 |
+
self.perturb_type,
|
1231 |
+
combo_perturbation_batch,
|
1232 |
+
self.forward_batch_size,
|
1233 |
+
layer_to_quant,
|
1234 |
+
original_emb,
|
1235 |
+
self.tokens_to_perturb,
|
1236 |
+
combo_indices_to_perturb,
|
1237 |
+
self.perturb_group,
|
1238 |
+
self.cell_states_to_model,
|
1239 |
+
state_embs_dict,
|
1240 |
+
self.pad_token_id,
|
1241 |
+
model_input_size,
|
1242 |
+
self.nproc)
|
1243 |
+
combo_cos_sims_data = combo_cos_sims_data.to("cuda")
|
1244 |
+
|
1245 |
+
# update cos sims dict
|
1246 |
+
# key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
|
1247 |
+
anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
|
1248 |
+
anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
|
1249 |
+
non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
|
1250 |
+
cos_sims_data = cos_sims_data[non_anchor_indices,:]
|
1251 |
+
|
1252 |
+
for j in range(cos_sims_data.shape[0]):
|
1253 |
+
|
1254 |
+
if j<anchor_index:
|
1255 |
+
j_index = torch.tensor([j])
|
1256 |
+
else:
|
1257 |
+
j_index = torch.tensor([j+1])
|
1258 |
+
|
1259 |
+
perturbed_gene = torch.index_select(gene_list, 0, j_index)
|
1260 |
+
perturbed_gene = perturbed_gene.item()
|
1261 |
+
|
1262 |
+
cell_cos_sim = torch.mean(cos_sims_data[j]).item()
|
1263 |
+
combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
|
1264 |
+
cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
|
1265 |
+
cell_cos_sim, # cos sim deleted gene alone
|
1266 |
+
combo_cos_sim)] # cos sim anchor gene + deleted gene
|
1267 |
+
|
1268 |
+
# save dict to disk every 100 cells
|
1269 |
+
if (i/100).is_integer():
|
1270 |
+
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
1271 |
+
pickle.dump(cos_sims_dict, fp)
|
1272 |
+
# reset and clear memory every 1000 cells
|
1273 |
+
if (i/1000).is_integer():
|
1274 |
+
pickle_batch = pickle_batch+1
|
1275 |
+
# clear memory
|
1276 |
+
del perturbed_gene
|
1277 |
+
del cos_sims_data
|
1278 |
+
if self.cell_states_to_model is None:
|
1279 |
+
del cell_cos_sim
|
1280 |
+
if self.cell_states_to_model is not None:
|
1281 |
+
del cell_data
|
1282 |
+
del data_list
|
1283 |
+
elif self.anchor_token is None:
|
1284 |
+
if self.emb_mode == "cell_and_gene":
|
1285 |
+
del affected_gene
|
1286 |
+
del cos_sim_value
|
1287 |
+
else:
|
1288 |
+
del combo_cos_sim
|
1289 |
+
del combo_cos_sims_data
|
1290 |
+
# reset dict
|
1291 |
+
del cos_sims_dict
|
1292 |
+
cos_sims_dict = defaultdict(list)
|
1293 |
+
torch.cuda.empty_cache()
|
1294 |
+
|
1295 |
+
# save remainder cells
|
1296 |
+
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
1297 |
+
pickle.dump(cos_sims_dict, fp)
|
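
A note on the output of the perturbation loop above: results are accumulated in a defaultdict keyed by tuples such as (perturbed_gene, "cell_emb") or (perturbed_gene, affected_gene) and pickled in batches to files ending in "_raw.pickle". The stats module that follows reads those files back; a minimal sketch of inspecting one output directory by hand is shown below (the directory path is a hypothetical placeholder, not part of the repository):

import pickle
from pathlib import Path

raw_dir = Path("path/to/output_directory")  # hypothetical directory containing *_raw.pickle files
for raw_file in sorted(raw_dir.glob("*_raw.pickle")):
    with open(raw_file, "rb") as fp:
        cos_sims_dict = pickle.load(fp)
    # keys are tuples like (perturbed_gene, "cell_emb"); values are lists of cosine similarity shifts
    for key, values in list(cos_sims_dict.items())[:5]:
        print(raw_file.name, key, len(values))
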
geneformer/in_silico_perturber_stats.py
ADDED
@@ -0,0 +1,716 @@
1 |
+
"""
|
2 |
+
Geneformer in silico perturber stats generator.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
from geneformer import InSilicoPerturberStats
|
6 |
+
ispstats = InSilicoPerturberStats(mode="goal_state_shift",
|
7 |
+
combos=0,
|
8 |
+
anchor_gene=None,
|
9 |
+
cell_states_to_model={"state_key": "disease",
|
10 |
+
"start_state": "dcm",
|
11 |
+
"goal_state": "nf",
|
12 |
+
"alt_states": ["hcm", "other1", "other2"]})
|
13 |
+
ispstats.get_stats("path/to/input_data",
|
14 |
+
None,
|
15 |
+
"path/to/output_directory",
|
16 |
+
"output_prefix")
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
import os
|
21 |
+
import logging
|
22 |
+
import numpy as np
|
23 |
+
import pandas as pd
|
24 |
+
import pickle
|
25 |
+
import random
|
26 |
+
import statsmodels.stats.multitest as smt
|
27 |
+
from pathlib import Path
|
28 |
+
from scipy.stats import ranksums
|
29 |
+
from sklearn.mixture import GaussianMixture
|
30 |
+
from tqdm.notebook import trange, tqdm
|
31 |
+
|
32 |
+
from .in_silico_perturber import flatten_list
|
33 |
+
|
34 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
35 |
+
|
36 |
+
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
# invert dictionary keys/values
|
41 |
+
def invert_dict(dictionary):
|
42 |
+
return {v: k for k, v in dictionary.items()}
|
43 |
+
|
44 |
+
# read raw dictionary files
|
45 |
+
def read_dictionaries(input_data_directory, cell_or_gene_emb, anchor_token):
|
46 |
+
file_found = 0
|
47 |
+
file_path_list = []
|
48 |
+
dict_list = []
|
49 |
+
for file in os.listdir(input_data_directory):
|
50 |
+
# process only _raw.pickle files
|
51 |
+
if file.endswith("_raw.pickle"):
|
52 |
+
file_found = 1
|
53 |
+
file_path_list += [f"{input_data_directory}/{file}"]
|
54 |
+
for file_path in tqdm(file_path_list):
|
55 |
+
with open(file_path, "rb") as fp:
|
56 |
+
cos_sims_dict = pickle.load(fp)
|
57 |
+
if cell_or_gene_emb == "cell":
|
58 |
+
cell_emb_dict = {k: v for k,
|
59 |
+
v in cos_sims_dict.items() if v and "cell_emb" in k}
|
60 |
+
dict_list += [cell_emb_dict]
|
61 |
+
elif cell_or_gene_emb == "gene":
|
62 |
+
gene_emb_dict = {k: v for k,
|
63 |
+
v in cos_sims_dict.items() if v and anchor_token == k[0]}
|
64 |
+
dict_list += [gene_emb_dict]
|
65 |
+
if file_found == 0:
|
66 |
+
logger.error(
|
67 |
+
"No raw data for processing found within provided directory. " \
|
68 |
+
"Please ensure data files end with '_raw.pickle'.")
|
69 |
+
raise
|
70 |
+
return dict_list
|
71 |
+
|
72 |
+
# get complete gene list
|
73 |
+
def get_gene_list(dict_list,mode):
|
74 |
+
if mode == "cell":
|
75 |
+
position = 0
|
76 |
+
elif mode == "gene":
|
77 |
+
position = 1
|
78 |
+
gene_set = set()
|
79 |
+
for dict_i in dict_list:
|
80 |
+
gene_set.update([k[position] for k, v in dict_i.items() if v])
|
81 |
+
gene_list = list(gene_set)
|
82 |
+
if mode == "gene":
|
83 |
+
gene_list.remove("cell_emb")
|
84 |
+
gene_list.sort()
|
85 |
+
return gene_list
|
86 |
+
|
87 |
+
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
|
88 |
+
return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
|
89 |
+
|
90 |
+
def n_detections(token, dict_list, mode, anchor_token):
|
91 |
+
cos_sim_megalist = []
|
92 |
+
for dict_i in dict_list:
|
93 |
+
if mode == "cell":
|
94 |
+
cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
|
95 |
+
elif mode == "gene":
|
96 |
+
cos_sim_megalist += dict_i.get((anchor_token, token),[])
|
97 |
+
return len(cos_sim_megalist)
|
98 |
+
|
99 |
+
def get_fdr(pvalues):
|
100 |
+
return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
|
101 |
+
|
102 |
+
def get_impact_component(test_value, gaussian_mixture_model):
|
103 |
+
impact_border = gaussian_mixture_model.means_[0][0]
|
104 |
+
nonimpact_border = gaussian_mixture_model.means_[1][0]
|
105 |
+
if test_value > nonimpact_border:
|
106 |
+
impact_component = 0
|
107 |
+
elif test_value < impact_border:
|
108 |
+
impact_component = 1
|
109 |
+
else:
|
110 |
+
impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
|
111 |
+
if impact_component_raw == 1:
|
112 |
+
impact_component = 0
|
113 |
+
elif impact_component_raw == 0:
|
114 |
+
impact_component = 1
|
115 |
+
return impact_component
|
116 |
+
|
117 |
+
# aggregate data for single perturbation in multiple cells
|
118 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
|
119 |
+
names=["Cosine_shift"]
|
120 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
121 |
+
|
122 |
+
cos_shift_data = []
|
123 |
+
token = cos_sims_df["Gene"][0]
|
124 |
+
for dict_i in dict_list:
|
125 |
+
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
126 |
+
cos_sims_full_df["Cosine_shift"] = cos_shift_data
|
127 |
+
return cos_sims_full_df
|
128 |
+
|
129 |
+
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
130 |
+
def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
|
131 |
+
cell_state_key = cell_states_to_model["start_state"]
|
132 |
+
if ("alt_states" not in cell_states_to_model.keys()) \
|
133 |
+
or (len(cell_states_to_model["alt_states"]) == 0) \
|
134 |
+
or (cell_states_to_model["alt_states"] == [None]):
|
135 |
+
alt_end_state_exists = False
|
136 |
+
elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
|
137 |
+
alt_end_state_exists = True
|
138 |
+
|
139 |
+
# for single perturbation in multiple cells, there are no random perturbations to compare to
|
140 |
+
if genes_perturbed != "all":
|
141 |
+
names=["Shift_to_goal_end",
|
142 |
+
"Shift_to_alt_end"]
|
143 |
+
if alt_end_state_exists == False:
|
144 |
+
names.remove("Shift_to_alt_end")
|
145 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
146 |
+
|
147 |
+
cos_shift_data = []
|
148 |
+
token = cos_sims_df["Gene"][0]
|
149 |
+
for dict_i in dict_list:
|
150 |
+
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
151 |
+
if alt_end_state_exists == False:
|
152 |
+
cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
|
153 |
+
if alt_end_state_exists == True:
|
154 |
+
cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
|
155 |
+
cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
|
156 |
+
|
157 |
+
# sort by shift to desired state
|
158 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
|
159 |
+
ascending=[False])
|
160 |
+
return cos_sims_full_df
|
161 |
+
|
162 |
+
elif genes_perturbed == "all":
|
163 |
+
random_tuples = []
|
164 |
+
for i in trange(cos_sims_df.shape[0]):
|
165 |
+
token = cos_sims_df["Gene"][i]
|
166 |
+
for dict_i in dict_list:
|
167 |
+
random_tuples += dict_i.get((token, "cell_emb"),[])
|
168 |
+
|
169 |
+
if alt_end_state_exists == False:
|
170 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
|
171 |
+
elif alt_end_state_exists == True:
|
172 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
|
173 |
+
alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
|
174 |
+
|
175 |
+
# downsample to improve speed of ranksums
|
176 |
+
if len(goal_end_random_megalist) > 100_000:
|
177 |
+
random.seed(42)
|
178 |
+
goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
|
179 |
+
if alt_end_state_exists == True:
|
180 |
+
if len(alt_end_random_megalist) > 100_000:
|
181 |
+
random.seed(42)
|
182 |
+
alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
|
183 |
+
|
184 |
+
names=["Gene",
|
185 |
+
"Gene_name",
|
186 |
+
"Ensembl_ID",
|
187 |
+
"Shift_to_goal_end",
|
188 |
+
"Shift_to_alt_end",
|
189 |
+
"Goal_end_vs_random_pval",
|
190 |
+
"Alt_end_vs_random_pval"]
|
191 |
+
if alt_end_state_exists == False:
|
192 |
+
names.remove("Shift_to_alt_end")
|
193 |
+
names.remove("Alt_end_vs_random_pval")
|
194 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
195 |
+
|
196 |
+
for i in trange(cos_sims_df.shape[0]):
|
197 |
+
token = cos_sims_df["Gene"][i]
|
198 |
+
name = cos_sims_df["Gene_name"][i]
|
199 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
200 |
+
cos_shift_data = []
|
201 |
+
|
202 |
+
for dict_i in dict_list:
|
203 |
+
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
204 |
+
|
205 |
+
if alt_end_state_exists == False:
|
206 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
|
207 |
+
elif alt_end_state_exists == True:
|
208 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
|
209 |
+
alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
|
210 |
+
mean_alt_end = np.mean(alt_end_cos_sim_megalist)
|
211 |
+
pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
|
212 |
+
|
213 |
+
mean_goal_end = np.mean(goal_end_cos_sim_megalist)
|
214 |
+
pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
|
215 |
+
|
216 |
+
if alt_end_state_exists == False:
|
217 |
+
data_i = [token,
|
218 |
+
name,
|
219 |
+
ensembl_id,
|
220 |
+
mean_goal_end,
|
221 |
+
pval_goal_end]
|
222 |
+
elif alt_end_state_exists == True:
|
223 |
+
data_i = [token,
|
224 |
+
name,
|
225 |
+
ensembl_id,
|
226 |
+
mean_goal_end,
|
227 |
+
mean_alt_end,
|
228 |
+
pval_goal_end,
|
229 |
+
pval_alt_end]
|
230 |
+
|
231 |
+
cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
|
232 |
+
cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
|
233 |
+
|
234 |
+
cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
|
235 |
+
if alt_end_state_exists == True:
|
236 |
+
cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
|
237 |
+
|
238 |
+
# quantify number of detections of each gene
|
239 |
+
cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
|
240 |
+
|
241 |
+
# sort by shift to desired state
|
242 |
+
cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
|
243 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
|
244 |
+
"Shift_to_goal_end",
|
245 |
+
"Goal_end_FDR"],
|
246 |
+
ascending=[False,False,True])
|
247 |
+
|
248 |
+
return cos_sims_full_df
|
249 |
+
|
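
The per-gene test in isp_stats_to_goal_state above reduces to a Wilcoxon rank-sum comparison of each gene's shifts toward the goal state against a pooled random background, followed by Benjamini-Hochberg correction of the resulting p-values. A stripped-down sketch of that statistical core, using made-up shift values rather than real perturbation output, is:

import numpy as np
from scipy.stats import ranksums
import statsmodels.stats.multitest as smt

# hypothetical per-gene shifts toward the goal state and a pooled random background
gene_shifts = {"gene_A": [0.020, 0.030, 0.010], "gene_B": [-0.010, 0.000, 0.005]}
random_background = [0.000, 0.010, -0.010, 0.005, -0.005]

pvals = {gene: ranksums(random_background, shifts).pvalue for gene, shifts in gene_shifts.items()}
fdrs = smt.multipletests(list(pvals.values()), alpha=0.05, method="fdr_bh")[1]
for (gene, pval), fdr in zip(pvals.items(), fdrs):
    print(gene, np.mean(gene_shifts[gene]), pval, fdr)
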
250 |
+
# stats comparing cos sim shifts of test perturbations vs null distribution
|
251 |
+
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
|
252 |
+
cos_sims_full_df = cos_sims_df.copy()
|
253 |
+
|
254 |
+
cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
255 |
+
cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
256 |
+
cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
257 |
+
cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
258 |
+
cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
259 |
+
cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
|
260 |
+
cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
|
261 |
+
|
262 |
+
for i in trange(cos_sims_df.shape[0]):
|
263 |
+
token = cos_sims_df["Gene"][i]
|
264 |
+
test_shifts = []
|
265 |
+
null_shifts = []
|
266 |
+
|
267 |
+
for dict_i in dict_list:
|
268 |
+
test_shifts += dict_i.get((token, "cell_emb"),[])
|
269 |
+
|
270 |
+
for dict_i in null_dict_list:
|
271 |
+
null_shifts += dict_i.get((token, "cell_emb"),[])
|
272 |
+
|
273 |
+
cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
|
274 |
+
cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
|
275 |
+
cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts)
|
276 |
+
cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts,
|
277 |
+
null_shifts, nan_policy="omit").pvalue
|
278 |
+
|
279 |
+
cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
|
280 |
+
cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
|
281 |
+
|
282 |
+
cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
|
283 |
+
|
284 |
+
cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
|
285 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
|
286 |
+
"Test_vs_null_avg_shift",
|
287 |
+
"Test_vs_null_FDR"],
|
288 |
+
ascending=[False,False,True])
|
289 |
+
return cos_sims_full_df
|
290 |
+
|
291 |
+
# stats for identifying perturbations with largest effect within a given set of cells
|
292 |
+
# fits a mixture model to 2 components (impact vs. non-impact) and
|
293 |
+
# reports the most likely component for each test perturbation
|
294 |
+
# Note: because this method assumes a given perturbation has a consistent effect in the cells tested,
|
295 |
+
# we recommend only using the mixture model strategy with uniform cell populations
|
296 |
+
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
297 |
+
|
298 |
+
names=["Gene",
|
299 |
+
"Gene_name",
|
300 |
+
"Ensembl_ID"]
|
301 |
+
|
302 |
+
if combos == 0:
|
303 |
+
names += ["Test_avg_shift"]
|
304 |
+
elif combos == 1:
|
305 |
+
names += ["Anchor_shift",
|
306 |
+
"Test_token_shift",
|
307 |
+
"Sum_of_indiv_shifts",
|
308 |
+
"Combo_shift",
|
309 |
+
"Combo_minus_sum_shift"]
|
310 |
+
|
311 |
+
names += ["Impact_component",
|
312 |
+
"Impact_component_percent"]
|
313 |
+
|
314 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
315 |
+
avg_values = []
|
316 |
+
gene_names = []
|
317 |
+
|
318 |
+
for i in trange(cos_sims_df.shape[0]):
|
319 |
+
token = cos_sims_df["Gene"][i]
|
320 |
+
name = cos_sims_df["Gene_name"][i]
|
321 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
322 |
+
cos_shift_data = []
|
323 |
+
|
324 |
+
for dict_i in dict_list:
|
325 |
+
if (combos == 0) and (anchor_token is not None):
|
326 |
+
cos_shift_data += dict_i.get((anchor_token, token),[])
|
327 |
+
else:
|
328 |
+
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
329 |
+
|
330 |
+
# Extract values for current gene
|
331 |
+
if combos == 0:
|
332 |
+
test_values = cos_shift_data
|
333 |
+
elif combos == 1:
|
334 |
+
test_values = []
|
335 |
+
for tup in cos_shift_data:
|
336 |
+
test_values.append(tup[2])
|
337 |
+
|
338 |
+
if len(test_values) > 0:
|
339 |
+
avg_value = np.mean(test_values)
|
340 |
+
avg_values.append(avg_value)
|
341 |
+
gene_names.append(name)
|
342 |
+
|
343 |
+
# fit Gaussian mixture model to dataset of mean for each gene
|
344 |
+
avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
|
345 |
+
gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
|
346 |
+
|
347 |
+
for i in trange(cos_sims_df.shape[0]):
|
348 |
+
token = cos_sims_df["Gene"][i]
|
349 |
+
name = cos_sims_df["Gene_name"][i]
|
350 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
351 |
+
cos_shift_data = []
|
352 |
+
|
353 |
+
for dict_i in dict_list:
|
354 |
+
if (combos == 0) and (anchor_token is not None):
|
355 |
+
cos_shift_data += dict_i.get((anchor_token, token),[])
|
356 |
+
else:
|
357 |
+
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
358 |
+
|
359 |
+
if combos == 0:
|
360 |
+
mean_test = np.mean(cos_shift_data)
|
361 |
+
impact_components = [get_impact_component(value,gm) for value in cos_shift_data]
|
362 |
+
elif combos == 1:
|
363 |
+
anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data]
|
364 |
+
token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data]
|
365 |
+
anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data]
|
366 |
+
combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data]
|
367 |
+
combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data]
|
368 |
+
|
369 |
+
mean_anchor = np.mean(anchor_cos_sim_megalist)
|
370 |
+
mean_token = np.mean(token_cos_sim_megalist)
|
371 |
+
mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
|
372 |
+
mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
|
373 |
+
mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
|
374 |
+
|
375 |
+
impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist]
|
376 |
+
|
377 |
+
impact_component = get_impact_component(mean_test,gm)
|
378 |
+
impact_component_percent = np.mean(impact_components)*100
|
379 |
+
|
380 |
+
data_i = [token,
|
381 |
+
name,
|
382 |
+
ensembl_id]
|
383 |
+
if combos == 0:
|
384 |
+
data_i += [mean_test]
|
385 |
+
elif combos == 1:
|
386 |
+
data_i += [mean_anchor,
|
387 |
+
mean_token,
|
388 |
+
mean_sum,
|
389 |
+
mean_test,
|
390 |
+
mean_combo_minus_sum]
|
391 |
+
data_i += [impact_component,
|
392 |
+
impact_component_percent]
|
393 |
+
|
394 |
+
cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
|
395 |
+
cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
|
396 |
+
|
397 |
+
# quantify number of detections of each gene
|
398 |
+
cos_sims_full_df["N_Detections"] = [n_detections(i,
|
399 |
+
dict_list,
|
400 |
+
"gene",
|
401 |
+
anchor_token) for i in cos_sims_full_df["Gene"]]
|
402 |
+
|
403 |
+
if combos == 0:
|
404 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
|
405 |
+
"Test_avg_shift"],
|
406 |
+
ascending=[False,True])
|
407 |
+
elif combos == 1:
|
408 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
|
409 |
+
"Combo_minus_sum_shift"],
|
410 |
+
ascending=[False,True])
|
411 |
+
return cos_sims_full_df
|
412 |
+
|
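
isp_stats_mixture_model above fits a two-component Gaussian mixture to the per-gene mean cosine shifts and then labels each perturbation as belonging to the impact or non-impact component via get_impact_component. A minimal sketch of that fitting step with hypothetical shift values follows; for simplicity it treats the component with the lower mean as the impact component, whereas get_impact_component above additionally applies boundary checks around the component means:

import numpy as np
from sklearn.mixture import GaussianMixture

# hypothetical per-gene mean cosine similarities (lower value = larger embedding change)
avg_shifts = np.array([0.98, 0.97, 0.99, 0.80, 0.78, 0.96]).reshape(-1, 1)
gm = GaussianMixture(n_components=2, random_state=0).fit(avg_shifts)

# treat the component with the lower mean as the "impact" component in this sketch
impact_component = int(np.argmin(gm.means_.ravel()))
labels = gm.predict(avg_shifts)
for value, label in zip(avg_shifts.ravel(), labels):
    print(float(value), "impact" if label == impact_component else "non-impact")
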
413 |
+
class InSilicoPerturberStats:
|
414 |
+
valid_option_dict = {
|
415 |
+
"mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
|
416 |
+
"combos": {0,1},
|
417 |
+
"anchor_gene": {None, str},
|
418 |
+
"cell_states_to_model": {None, dict},
|
419 |
+
}
|
420 |
+
def __init__(
|
421 |
+
self,
|
422 |
+
mode="mixture_model",
|
423 |
+
genes_perturbed="all",
|
424 |
+
combos=0,
|
425 |
+
anchor_gene=None,
|
426 |
+
cell_states_to_model=None,
|
427 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
428 |
+
gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
|
429 |
+
):
|
430 |
+
"""
|
431 |
+
Initialize in silico perturber stats generator.
|
432 |
+
|
433 |
+
Parameters
|
434 |
+
----------
|
435 |
+
mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
|
436 |
+
Type of stats.
|
437 |
+
"goal_state_shift": perturbation vs. random for desired cell state shift
|
438 |
+
"vs_null": perturbation vs. null from provided null distribution dataset
|
439 |
+
"mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
|
440 |
+
"aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
|
441 |
+
genes_perturbed : "all", list
|
442 |
+
Genes perturbed in isp experiment.
|
443 |
+
Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
444 |
+
Otherwise, may provide a list of ENSEMBL IDs of genes perturbed together as a group.
|
445 |
+
combos : {0,1,2}
|
446 |
+
Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
|
447 |
+
anchor_gene : None, str
|
448 |
+
ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
|
449 |
+
For example, if combos=1 and anchor_gene="ENSG00000136574":
|
450 |
+
analyzes data for anchor gene perturbed in combination with each other gene.
|
451 |
+
However, if combos=0 and anchor_gene="ENSG00000136574":
|
452 |
+
analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
|
453 |
+
cell_states_to_model: None, dict
|
454 |
+
Cell states to model if testing perturbations that achieve goal state change.
|
455 |
+
Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
456 |
+
state_key: key specifying name of column in .dataset that defines the start/goal states
|
457 |
+
start_state: value in the state_key column that specifies the start state
|
458 |
+
goal_state: value in the state_key column that specifies the goal end state
|
459 |
+
alt_states: list of values in the state_key column that specify the alternate end states
|
460 |
+
For example: {"state_key": "disease",
|
461 |
+
"start_state": "dcm",
|
462 |
+
"goal_state": "nf",
|
463 |
+
"alt_states": ["hcm", "other1", "other2"]}
|
464 |
+
token_dictionary_file : Path
|
465 |
+
Path to pickle file containing token dictionary (Ensembl ID:token).
|
466 |
+
gene_name_id_dictionary_file : Path
|
467 |
+
Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
|
468 |
+
"""
|
469 |
+
|
470 |
+
self.mode = mode
|
471 |
+
self.genes_perturbed = genes_perturbed
|
472 |
+
self.combos = combos
|
473 |
+
self.anchor_gene = anchor_gene
|
474 |
+
self.cell_states_to_model = cell_states_to_model
|
475 |
+
|
476 |
+
self.validate_options()
|
477 |
+
|
478 |
+
# load token dictionary (Ensembl IDs:token)
|
479 |
+
with open(token_dictionary_file, "rb") as f:
|
480 |
+
self.gene_token_dict = pickle.load(f)
|
481 |
+
|
482 |
+
# load gene name dictionary (gene name:Ensembl ID)
|
483 |
+
with open(gene_name_id_dictionary_file, "rb") as f:
|
484 |
+
self.gene_name_id_dict = pickle.load(f)
|
485 |
+
|
486 |
+
if anchor_gene is None:
|
487 |
+
self.anchor_token = None
|
488 |
+
else:
|
489 |
+
self.anchor_token = self.gene_token_dict[self.anchor_gene]
|
490 |
+
|
491 |
+
def validate_options(self):
|
492 |
+
for attr_name,valid_options in self.valid_option_dict.items():
|
493 |
+
attr_value = self.__dict__[attr_name]
|
494 |
+
if type(attr_value) not in {list, dict}:
|
495 |
+
if attr_name in {"anchor_gene"}:
|
496 |
+
continue
|
497 |
+
elif attr_value in valid_options:
|
498 |
+
continue
|
499 |
+
valid_type = False
|
500 |
+
for option in valid_options:
|
501 |
+
if (option in [int,list,dict]) and isinstance(attr_value, option):
|
502 |
+
valid_type = True
|
503 |
+
break
|
504 |
+
if valid_type:
|
505 |
+
continue
|
506 |
+
logger.error(
|
507 |
+
f"Invalid option for {attr_name}. " \
|
508 |
+
f"Valid options for {attr_name}: {valid_options}"
|
509 |
+
)
|
510 |
+
raise
|
511 |
+
|
512 |
+
if self.cell_states_to_model is not None:
|
513 |
+
if len(self.cell_states_to_model.items()) == 1:
|
514 |
+
logger.warning(
|
515 |
+
"The single value dictionary for cell_states_to_model will be " \
|
516 |
+
"replaced with a dictionary with named keys for start, goal, and alternate states. " \
|
517 |
+
"Please specify state_key, start_state, goal_state, and alt_states " \
|
518 |
+
"in the cell_states_to_model dictionary for future use. " \
|
519 |
+
"For example, cell_states_to_model={" \
|
520 |
+
"'state_key': 'disease', " \
|
521 |
+
"'start_state': 'dcm', " \
|
522 |
+
"'goal_state': 'nf', " \
|
523 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
524 |
+
)
|
525 |
+
for key,value in self.cell_states_to_model.items():
|
526 |
+
if (len(value) == 3) and isinstance(value, tuple):
|
527 |
+
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
528 |
+
if len(value[0]) == 1 and len(value[1]) == 1:
|
529 |
+
all_values = value[0]+value[1]+value[2]
|
530 |
+
if len(all_values) == len(set(all_values)):
|
531 |
+
continue
|
532 |
+
# reformat to the new named key format
|
533 |
+
state_values = flatten_list(list(self.cell_states_to_model.values()))
|
534 |
+
self.cell_states_to_model = {
|
535 |
+
"state_key": list(self.cell_states_to_model.keys())[0],
|
536 |
+
"start_state": state_values[0][0],
|
537 |
+
"goal_state": state_values[1][0],
|
538 |
+
"alt_states": state_values[2:][0]
|
539 |
+
}
|
540 |
+
elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
|
541 |
+
if (self.cell_states_to_model["state_key"] is None) \
|
542 |
+
or (self.cell_states_to_model["start_state"] is None) \
|
543 |
+
or (self.cell_states_to_model["goal_state"] is None):
|
544 |
+
logger.error(
|
545 |
+
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
|
546 |
+
raise
|
547 |
+
|
548 |
+
if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
|
549 |
+
logger.error(
|
550 |
+
"All states must be unique.")
|
551 |
+
raise
|
552 |
+
|
553 |
+
if self.cell_states_to_model["alt_states"] is not None:
|
554 |
+
if type(self.cell_states_to_model["alt_states"]) is not list:
|
555 |
+
logger.error(
|
556 |
+
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
557 |
+
)
|
558 |
+
raise
|
559 |
+
if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
|
560 |
+
logger.error(
|
561 |
+
"All states must be unique.")
|
562 |
+
raise
|
563 |
+
|
564 |
+
else:
|
565 |
+
logger.error(
|
566 |
+
"cell_states_to_model must only have the following four keys: " \
|
567 |
+
"'state_key', 'start_state', 'goal_state', 'alt_states'." \
|
568 |
+
"For example, cell_states_to_model={" \
|
569 |
+
"'state_key': 'disease', " \
|
570 |
+
"'start_state': 'dcm', " \
|
571 |
+
"'goal_state': 'nf', " \
|
572 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
573 |
+
)
|
574 |
+
raise
|
575 |
+
|
576 |
+
if self.anchor_gene is not None:
|
577 |
+
self.anchor_gene = None
|
578 |
+
logger.warning(
|
579 |
+
"anchor_gene set to None. " \
|
580 |
+
"Currently, anchor gene not available " \
|
581 |
+
"when modeling multiple cell states.")
|
582 |
+
|
583 |
+
if self.combos > 0:
|
584 |
+
if self.anchor_gene is None:
|
585 |
+
logger.error(
|
586 |
+
"Currently, stats are only supported for combination " \
|
587 |
+
"in silico perturbation run with anchor gene. Please add " \
|
588 |
+
"anchor gene when using with combos > 0. ")
|
589 |
+
raise
|
590 |
+
|
591 |
+
if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
|
592 |
+
logger.error(
|
593 |
+
"Mixture model mode requires multiple gene perturbations to fit model " \
|
594 |
+
"so is incompatible with a single grouped perturbation.")
|
595 |
+
raise
|
596 |
+
if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
|
597 |
+
logger.error(
|
598 |
+
"Simple data aggregation mode is for single perturbation in multiple cells " \
|
599 |
+
"so is incompatible with a genes_perturbed being 'all'.")
|
600 |
+
raise
|
601 |
+
|
602 |
+
def get_stats(self,
|
603 |
+
input_data_directory,
|
604 |
+
null_dist_data_directory,
|
605 |
+
output_directory,
|
606 |
+
output_prefix):
|
607 |
+
"""
|
608 |
+
Get stats for in silico perturbation data and save as results in output_directory.
|
609 |
+
|
610 |
+
Parameters
|
611 |
+
----------
|
612 |
+
input_data_directory : Path
|
613 |
+
Path to directory containing cos_sim dictionary inputs
|
614 |
+
null_dist_data_directory : Path
|
615 |
+
Path to directory containing null distribution cos_sim dictionary inputs
|
616 |
+
output_directory : Path
|
617 |
+
Path to directory where perturbation data will be saved as .csv
|
618 |
+
output_prefix : str
|
619 |
+
Prefix for output .csv
|
620 |
+
|
621 |
+
Outputs
|
622 |
+
----------
|
623 |
+
Definition of possible columns in .csv output file.
|
624 |
+
|
625 |
+
Of note, not all columns will be present in all output files.
|
626 |
+
Some columns are specific to particular perturbation modes.
|
627 |
+
|
628 |
+
"Gene": gene token
|
629 |
+
"Gene_name": gene name
|
630 |
+
"Ensembl_ID": gene Ensembl ID
|
631 |
+
"N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
|
632 |
+
"Sig": 1 if FDR<0.05, otherwise 0
|
633 |
+
|
634 |
+
"Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
|
635 |
+
"Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
|
636 |
+
"Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
|
637 |
+
pvalue compares the shift caused by perturbing the given gene vs. the shift caused by random genes
|
638 |
+
"Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
|
639 |
+
pvalue compares the shift caused by perturbing the given gene vs. the shift caused by random genes
|
640 |
+
"Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
|
641 |
+
"Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
|
642 |
+
|
643 |
+
"Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
|
644 |
+
"Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
|
645 |
+
"Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
|
646 |
+
(i.e. "Test_avg_shift" minus "Null_avg_shift")
|
647 |
+
"Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
|
648 |
+
"Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
|
649 |
+
"N_Detections_test": "N_Detections" in cells from test distribution
|
650 |
+
"N_Detections_null": "N_Detections" in cells from null distribution
|
651 |
+
|
652 |
+
"Anchor_shift": cosine shift in response to given perturbation of anchor gene
|
653 |
+
"Test_token_shift": cosine shift in response to given perturbation of test gene
|
654 |
+
"Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
|
655 |
+
"Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
|
656 |
+
"Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
|
657 |
+
(i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
|
658 |
+
"Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
|
659 |
+
1: within impact component; 0: not within impact component
|
660 |
+
"Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
661 |
+
"""
|
662 |
+
|
663 |
+
if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
|
664 |
+
logger.error(
|
665 |
+
"Currently, only modes available are stats for goal_state_shift, " \
|
666 |
+
"vs_null (comparing to null distribution), and " \
|
667 |
+
"mixture_model (fitting mixture model for perturbations with or without impact.")
|
668 |
+
raise
|
669 |
+
|
670 |
+
self.gene_token_id_dict = invert_dict(self.gene_token_dict)
|
671 |
+
self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
|
672 |
+
|
673 |
+
# obtain total gene list
|
674 |
+
if (self.combos == 0) and (self.anchor_token is not None):
|
675 |
+
# cos sim data for effect of gene perturbation on the embedding of each other gene
|
676 |
+
dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token)
|
677 |
+
gene_list = get_gene_list(dict_list, "gene")
|
678 |
+
else:
|
679 |
+
# cos sim data for effect of gene perturbation on the embedding of each cell
|
680 |
+
dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token)
|
681 |
+
gene_list = get_gene_list(dict_list, "cell")
|
682 |
+
|
683 |
+
# initiate results dataframe
|
684 |
+
cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
|
685 |
+
"Gene_name": [self.token_to_gene_name(item) \
|
686 |
+
for item in gene_list], \
|
687 |
+
"Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
|
688 |
+
if self.genes_perturbed != "all" else \
|
689 |
+
self.gene_token_id_dict[genes[1]] \
|
690 |
+
if isinstance(genes,tuple) else \
|
691 |
+
self.gene_token_id_dict[genes] \
|
692 |
+
for genes in gene_list]}, \
|
693 |
+
index=[i for i in range(len(gene_list))])
|
694 |
+
|
695 |
+
if self.mode == "goal_state_shift":
|
696 |
+
cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
|
697 |
+
|
698 |
+
elif self.mode == "vs_null":
|
699 |
+
null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
|
700 |
+
cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
|
701 |
+
|
702 |
+
elif self.mode == "mixture_model":
|
703 |
+
cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
|
704 |
+
|
705 |
+
elif self.mode == "aggregate_data":
|
706 |
+
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
707 |
+
|
708 |
+
# save perturbation stats to output_path
|
709 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
710 |
+
cos_sims_df.to_csv(output_path)
|
711 |
+
|
712 |
+
def token_to_gene_name(self, item):
|
713 |
+
if isinstance(item,int):
|
714 |
+
return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
|
715 |
+
if isinstance(item,tuple):
|
716 |
+
return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])
|
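
Putting the class together with the module docstring above, a typical end-to-end call of the stats generator might look like the sketch below. The directory paths and prefix are placeholders; the second argument is only read when mode="vs_null", in which case it should point to a directory of null-distribution *_raw.pickle files.

from geneformer import InSilicoPerturberStats

ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model={"state_key": "disease",
                                                        "start_state": "dcm",
                                                        "goal_state": "nf",
                                                        "alt_states": ["hcm", "other1", "other2"]})
ispstats.get_stats("path/to/input_data",       # directory of *_raw.pickle files from the perturber
                   None,                       # null-distribution directory (used only for mode="vs_null")
                   "path/to/output_directory",
                   "output_prefix")            # results saved as output_prefix.csv in the output directory
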
geneformer/pretrainer.py
ADDED
@@ -0,0 +1,822 @@
1 |
+
"""
|
2 |
+
Geneformer precollator and pretrainer.
|
3 |
+
|
4 |
+
Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
|
5 |
+
"""
|
6 |
+
import collections
|
7 |
+
import math
|
8 |
+
import pickle
|
9 |
+
import warnings
|
10 |
+
from enum import Enum
|
11 |
+
from typing import Dict, Iterator, List, Optional, Union
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from datasets import Dataset
|
16 |
+
from packaging import version
|
17 |
+
from torch.utils.data.distributed import DistributedSampler
|
18 |
+
from torch.utils.data.sampler import RandomSampler
|
19 |
+
from transformers import (
|
20 |
+
BatchEncoding,
|
21 |
+
DataCollatorForLanguageModeling,
|
22 |
+
SpecialTokensMixin,
|
23 |
+
Trainer,
|
24 |
+
)
|
25 |
+
from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
|
26 |
+
from transformers.trainer_pt_utils import (
|
27 |
+
DistributedLengthGroupedSampler,
|
28 |
+
DistributedSamplerWithLoop,
|
29 |
+
LengthGroupedSampler,
|
30 |
+
)
|
31 |
+
from transformers.training_args import ParallelMode
|
32 |
+
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
33 |
+
from transformers.utils.generic import _is_tensorflow, _is_torch
|
34 |
+
|
35 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
36 |
+
|
37 |
+
logger = logging.get_logger(__name__)
|
38 |
+
EncodedInput = List[int]
|
39 |
+
VERY_LARGE_INTEGER = int(
|
40 |
+
1e30
|
41 |
+
) # This is used to set the max input length for a model with infinite size input
|
42 |
+
LARGE_INTEGER = int(
|
43 |
+
1e20
|
44 |
+
) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
45 |
+
|
46 |
+
if is_sagemaker_dp_enabled():
|
47 |
+
import smdistributed.dataparallel.torch.distributed as dist
|
48 |
+
else:
|
49 |
+
import torch.distributed as dist
|
50 |
+
|
51 |
+
_is_torch_generator_available = False
|
52 |
+
if version.parse(torch.__version__) >= version.parse("1.6"):
|
53 |
+
_is_torch_generator_available = True
|
54 |
+
|
55 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
56 |
+
token_dictionary = pickle.load(f)
|
57 |
+
|
58 |
+
|
59 |
+
class ExplicitEnum(Enum):
|
60 |
+
"""
|
61 |
+
Enum with more explicit error message for missing values.
|
62 |
+
"""
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def _missing_(cls, value):
|
66 |
+
raise ValueError(
|
67 |
+
"%r is not a valid %s, please select one of %s"
|
68 |
+
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
class TruncationStrategy(ExplicitEnum):
|
73 |
+
"""
|
74 |
+
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
75 |
+
tab-completion in an IDE.
|
76 |
+
"""
|
77 |
+
|
78 |
+
ONLY_FIRST = "only_first"
|
79 |
+
ONLY_SECOND = "only_second"
|
80 |
+
LONGEST_FIRST = "longest_first"
|
81 |
+
DO_NOT_TRUNCATE = "do_not_truncate"
|
82 |
+
|
83 |
+
|
84 |
+
class PaddingStrategy(ExplicitEnum):
|
85 |
+
"""
|
86 |
+
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
87 |
+
in an IDE.
|
88 |
+
"""
|
89 |
+
|
90 |
+
LONGEST = "longest"
|
91 |
+
MAX_LENGTH = "max_length"
|
92 |
+
DO_NOT_PAD = "do_not_pad"
|
93 |
+
|
94 |
+
|
95 |
+
class TensorType(ExplicitEnum):
|
96 |
+
"""
|
97 |
+
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
98 |
+
tab-completion in an IDE.
|
99 |
+
"""
|
100 |
+
|
101 |
+
PYTORCH = "pt"
|
102 |
+
TENSORFLOW = "tf"
|
103 |
+
NUMPY = "np"
|
104 |
+
JAX = "jax"
|
105 |
+
|
106 |
+
|
107 |
+
class GeneformerPreCollator(SpecialTokensMixin):
    def __init__(self, *args, **kwargs) -> None:

        super().__init__(mask_token="<mask>", pad_token="<pad>")

        self.token_dictionary = kwargs.get("token_dictionary")
        # self.mask_token = "<mask>"
        # self.mask_token_id = self.token_dictionary.get("<mask>")
        # self.pad_token = "<pad>"
        # self.pad_token_id = self.token_dictionary.get("<pad>")
        self.padding_side = "right"
        # self.all_special_ids = [
        #     self.token_dictionary.get("<mask>"),
        #     self.token_dictionary.get("<pad>"),
        # ]
        self.model_input_names = ["input_ids"]

    def convert_ids_to_tokens(self, value):
        return self.token_dictionary.get(value)

    def _get_padding_truncation_strategies(
        self,
        padding=False,
        truncation=False,
        max_length=None,
        pad_to_multiple_of=None,
        verbose=True,
        **kwargs,
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments
        (truncation_strategy and pad_to_max_length) and behaviors.
        """
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                # Default to pad to the longest sequence in the batch
                padding_strategy = PaddingStrategy.LONGEST
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                # Default to truncate the longest sequences in pairs of inputs
                truncation_strategy = TruncationStrategy.LONGEST_FIRST
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            elif isinstance(truncation, TruncationStrategy):
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
            raise ValueError(
                "Asking to pad but the tokenizer does not have a padding token. "
                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
            )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs

    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) and padding token ids are defined at the tokenizer level (with ``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``).

        .. note::

            If the ``encoded_inputs`` passed are dictionaries of numpy arrays, PyTorch tensors or TensorFlow tensors,
            the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
            the case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask (:obj:`bool`, `optional`):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

                `What are attention masks? <../glossary.html#attention-mask>`__
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
            encoded_inputs = {
                key: [example[key] for example in encoded_inputs]
                for key in encoded_inputs[0].keys()
            }

        # The model's main input name, usually `input_ids`, has to be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method"
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            index = 0
            while len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]
        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[self.model_input_names[0]]
        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        return BatchEncoding(batch_outputs, tensor_type=return_tensors)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch).

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[self.model_input_names[0]]

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = (
            padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
        )

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = (
                        encoded_inputs["special_tokens_mask"] + [1] * difference
                    )
                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = (
                        [1] * difference + encoded_inputs["special_tokens_mask"]
                    )
                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        return encoded_inputs

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                List of ids of the second sequence.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        assert already_has_special_tokens and token_ids_1 is None, (
            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
            "Please use a slow (full python) tokenizer to activate this argument."
            "Or set `return_special_tokens_mask=True` when calling the encoding method "
            "to get the special tokens mask in any tokenizer. "
        )

        all_special_ids = self.all_special_ids  # cache the property

        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]

        return special_tokens_mask

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token is None:
            return None

        return self.token_dictionary.get(token)

    def __len__(self):
        return len(self.token_dictionary)

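# Usage note (illustrative, not part of the original file): the precollator exposes just
# enough of the Hugging Face tokenizer interface (pad/mask tokens, padding_side,
# model_input_names) for DataCollatorForLanguageModeling to pad and mask rank value
# encodings. A minimal sketch, assuming "<pad>" maps to a non-negative id in
# token_dictionary.pkl; the input_ids below are made-up examples:
#
#     precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
#     batch = [{"input_ids": [5, 17, 42]}, {"input_ids": [8, 3]}]
#     padded = precollator.pad(batch, padding="longest", return_tensors="pt")
#     padded["attention_mask"][1]  # -> tensor([1, 1, 0]); the shorter cell is right-padded
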
class GeneformerPretrainer(Trainer):
    def __init__(self, *args, **kwargs):
        data_collator = kwargs.get("data_collator", None)
        token_dictionary = kwargs.pop("token_dictionary")

        if data_collator is None:
            precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

            # # Data Collator Functions
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=precollator, mlm=True, mlm_probability=0.15
            )
            kwargs["data_collator"] = data_collator

        # load previously saved length vector for dataset to speed up LengthGroupedSampler
        # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
        example_lengths_file = kwargs.pop("example_lengths_file")
        if example_lengths_file:
            with open(example_lengths_file, "rb") as f:
                self.example_lengths = pickle.load(f)
        else:
            raise Exception(
                "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
            )
        super().__init__(*args, **kwargs)

    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if not isinstance(self.train_dataset, collections.abc.Sized):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, Dataset):
                lengths = self.example_lengths
            else:
                lengths = None
            model_input_name = (
                self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            )
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return CustomDistributedLengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=self.args.seed,
                )

        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode
                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )

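# Usage note (illustrative, not part of the original file): a minimal sketch of wiring the
# pretrainer together; the dataset path, lengths file, BertConfig values and
# TrainingArguments below are placeholders, not values prescribed by this commit.
#
#     with open("token_dictionary.pkl", "rb") as f:
#         token_dictionary = pickle.load(f)
#     model = BertForMaskedLM(BertConfig(vocab_size=len(token_dictionary),
#                                        max_position_embeddings=2048,
#                                        pad_token_id=token_dictionary.get("<pad>")))
#     trainer = GeneformerPretrainer(
#         model=model,
#         args=TrainingArguments(output_dir="./pretrain_out", do_train=True,
#                                group_by_length=True, length_column_name="length"),
#         train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
#         token_dictionary=token_dictionary,
#         example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",
#     )
#     trainer.train()
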
class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
    r"""
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed
        self.model_input_name = (
            model_input_name if model_input_name is not None else "input_ids"
        )

        if lengths is None:
            print("Lengths is none - calculating lengths.")
            if (
                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
                or self.model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{self.model_input_name}' key."
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
    """
    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
    similar lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [
        indices[i : i + megabatch_size].tolist()
        for i in range(0, len(lengths), megabatch_size)
    ]
    megabatches = [
        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
        for megabatch in megabatches
    ]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = (
        megabatches[max_idx][0],
        megabatches[0][0],
    )

    return [item for sublist in megabatches for item in sublist]
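
A quick way to see what the length grouping does is to run it on toy data; this is a sketch on made-up lengths, assuming the module is importable as geneformer.pretrainer after installation.

    import torch
    from geneformer.pretrainer import get_length_grouped_indices

    lengths = [5, 9, 2, 8, 3, 7, 6, 4]  # hypothetical example lengths
    g = torch.Generator().manual_seed(0)
    idx = get_length_grouped_indices(lengths, batch_size=2, generator=g)
    # within each mega-batch the indices are sorted by descending length, and the
    # globally longest example lands in the first batch so OOMs surface early
    print([lengths[i] for i in idx])
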
geneformer/token_dictionary.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcf53d1c87c08786f73aaf7c09da9778bfb8299e86b03411daa4143ac64ac0a7
size 270111
geneformer/tokenizer.py
ADDED
@@ -0,0 +1,235 @@
"""
Geneformer tokenizer.

Input data:
Required format: raw counts scRNAseq data without feature selection as .loom file
Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene
Required col (cell) attribute: "n_counts"; total read counts in that cell
Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria
Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below

Usage:
  from geneformer import TranscriptomeTokenizer
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
"""

import pickle
from pathlib import Path

import logging

import warnings
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

import loompy as lp
import numpy as np
from datasets import Dataset

logger = logging.getLogger(__name__)

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # create array of gene vector with token indices
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    # tokenize
    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
    return sentence_tokens


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : None, dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl IDs:token).
        """
        # dictionary of custom attributes {output dataset column name: input .loom column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
        """
        Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        loom_data_directory : Path
            Path to directory containing loom files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
        """
        tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)

    def tokenize_files(self, loom_data_directory):
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
            cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}

        # loops through directories to tokenize .loom files
        file_found = 0
        for loom_file_path in loom_data_directory.glob("*.loom"):
            file_found = 1
            print(f"Tokenizing {loom_file_path}")
            file_tokenized_cells, file_cell_metadata = self.tokenize_file(loom_file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in loom_cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
            else:
                cell_metadata = None

        if file_found == 0:
            logger.error(f"No .loom files found in directory {loom_data_directory}.")
            raise
        return tokenized_cells, cell_metadata

    def tokenize_file(self, loom_file_path):
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [self.gene_median_dict[i] for i in data.ra["ensembl_id"][coding_miRNA_loc]]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of cells passing filters for inclusion (e.g. QC)
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists is True:
                filter_pass_loc = np.where(
                    [True if i == 1 else False for i in data.ca["filter_pass"]]
                )[0]
            elif var_exists is False:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.array([i for i in range(data.shape[1])])

            # scan through .loom files and tokenize cells
            tokenized_cells = []
            for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
                # select subview with protein-coding and miRNA genes
                subview = view.view[coding_miRNA_loc, :]

                # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
                # and normalize by gene normalization factors
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * 10_000
                    / norm_factor_vector[:, None]
                )
                # tokenize subview gene vectors
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                # add custom attributes for subview to dict
                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()
                else:
                    file_cell_metadata = None

        return tokenized_cells, file_cell_metadata

    def create_dataset(self, tokenized_cells, cell_metadata):
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create dataset
        output_dataset = Dataset.from_dict(dataset_dict)

        # truncate dataset
        def truncate(example):
            example["input_ids"] = example["input_ids"][0:2048]
            return example

        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)

        # measure lengths of dataset
        def measure_length(example):
            example["length"] = len(example["input_ids"])
            return example

        output_dataset_truncated_w_length = output_dataset_truncated.map(
            measure_length, num_proc=self.nproc
        )

        return output_dataset_truncated_w_length
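
The rank value encoding implemented by tokenize_cell can be checked on a toy vector; this sketch uses made-up token ids and assumes the package is importable after installation.

    import numpy as np
    from geneformer.tokenizer import tokenize_cell

    gene_vector = np.array([0.5, 0.0, 2.0, 1.2])    # median-normalized expression; the second gene is undetected
    gene_tokens = np.array([101, 102, 103, 104])    # made-up token ids for the four genes
    print(tokenize_cell(gene_vector, gene_tokens))  # -> [103 104 101]: detected genes, highest normalized value first
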
generation_config.json
ADDED
@@ -0,0 +1,5 @@
{
  "_from_model_config": true,
  "pad_token_id": 0,
  "transformers_version": "4.32.0"
}
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:199d33652d295dfe6ef97b3d3dccdc2f528931ffbe683243ec5a70842637e329
size 31494773
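
pytorch_model.bin holds the pretrained weights and is normally loaded through transformers rather than read directly. A sketch, assuming the repository id ctheodoris/Geneformer and a BERT masked-language-model checkpoint:

    from transformers import BertForMaskedLM, BertModel

    mlm_model = BertForMaskedLM.from_pretrained("ctheodoris/Geneformer")  # with the pretraining head
    encoder = BertModel.from_pretrained("ctheodoris/Geneformer")          # bare encoder, e.g. for extracting embeddings
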
setup.py
ADDED
@@ -0,0 +1,21 @@
from setuptools import setup

setup(
    name="geneformer",
    version="0.0.1",
    author="Christina Theodoris",
    author_email="christina.theodoris@gladstone.ucsf.edu",
    description="Geneformer is a transformer model pretrained \
                 on a large-scale corpus of ~30 million single \
                 cell transcriptomes to enable context-aware \
                 predictions in settings with limited data in \
                 network biology.",
    packages=["geneformer"],
    include_package_data=True,
    install_requires=[
        "datasets",
        "loompy",
        "numpy",
        "transformers",
    ],
)
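
Because include_package_data=True is set, the pickled gene-median and token dictionaries inside the geneformer/ package are installed alongside the code, so a local install of a clone is enough to use the tokenizer and pretrainer; the clone path below is a placeholder.

    # pip install ./Geneformer   (after cloning this repository locally)
    from geneformer import TranscriptomeTokenizer
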
training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6aa4702ebe332247df4beb04ad957645d492c05ebca9f9b600770a7c658e7800
size 4219