jinbo1129 committed on
Commit
0d75654
1 Parent(s): d85fc99

commit from jbz 2023.08.29 for the first time

MANIFEST.in ADDED
@@ -0,0 +1,3 @@
+ include geneformer/gene_median_dictionary.pkl
+ include geneformer/token_dictionary.pkl
+ include geneformer/gene_name_id_dict.pkl
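
The three MANIFEST.in entries above bundle pickled lookup dictionaries with the package. As a minimal sketch (assuming the files deserialize to plain Python dicts and that the package has been installed per the README below), they could be inspected like this; the printed fields are illustrative only:

```python
# Hedged sketch: peek at the dictionaries shipped inside the installed geneformer package.
# Assumption: each .pkl file is a plain Python dict (the key/value semantics are defined
# by the Geneformer documentation, not by this snippet).
import pickle
from pathlib import Path

import geneformer

pkg_dir = Path(geneformer.__file__).parent
for name in ["gene_median_dictionary.pkl", "token_dictionary.pkl", "gene_name_id_dict.pkl"]:
    with open(pkg_dir / name, "rb") as f:
        d = pickle.load(f)
    example_key = next(iter(d))  # show one arbitrary entry to reveal the mapping shape
    print(f"{name}: {len(d):,} entries, e.g. {example_key!r} -> {d[example_key]!r}")
```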
README.md ADDED
@@ -0,0 +1,67 @@
+ ---
+ datasets: ctheodoris/Genecorpus-30M
+ license: apache-2.0
+ ---
+ # Geneformer
+ Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in network biology settings with limited data.
+
+ See [our manuscript](https://rdcu.be/ddrx0) for details.
+
+ # Model Description
+ Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias absolute transcript count values while the overall relative ranking of genes within each cell remains more stable.
+
+ The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
+
+ We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
+
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents a pretrained deep learning model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
+
+ In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6-layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12-layer Geneformer model, scaled up with a retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
+
+ # Application
+ The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
+
+ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
+
+ *Fine-tuning*:
+ - transcription factor dosage sensitivity
+ - chromatin dynamics (bivalently marked promoters)
+ - transcription factor regulatory range
+ - gene network centrality
+ - transcription factor targets
+ - cell type annotation
+ - batch integration
+ - cell state classification across differentiation
+ - disease classification
+ - in silico perturbation to determine disease-driving genes
+ - in silico treatment to determine candidate therapeutic targets
+
+ *Zero-shot learning*:
+ - batch integration
+ - gene context specificity
+ - in silico reprogramming
+ - in silico differentiation
+ - in silico perturbation to determine impact on cell state
+ - in silico perturbation to determine transcription factor targets
+ - in silico perturbation to determine transcription factor cooperativity
+
+ # Installation
+ In addition to the pretrained model, this repository contains functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico perturbation with either the pretrained or fine-tuned models. To install:
+
+ ```bash
+ git clone https://huggingface.co/ctheodoris/Geneformer
+ cd Geneformer
+ pip install .
+ ```
+
+ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main/examples) for:
+ - tokenizing transcriptomes
+ - pretraining
+ - hyperparameter tuning
+ - fine-tuning
+ - extracting and plotting cell embeddings
+ - in silico perturbation
+
+ Please note that the fine-tuning examples are meant to be generally applicable; the input datasets and labels will vary depending on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
+
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application, as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
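
The rank value encoding described in the Model Description above can be pictured with a toy example. The sketch below is illustrative only (made-up gene names and counts, not the packaged tokenizer): each gene's expression in the cell is divided by its corpus-wide median, and genes are then ordered by the normalized value.

```python
# Toy illustration of the rank value encoding idea (not the repository's tokenizer).
import numpy as np

genes = np.array(["HOUSEKEEPING_GENE", "TF_A", "TF_B", "OTHER_GENE"])
cell_counts = np.array([900.0, 12.0, 8.0, 40.0])    # hypothetical counts in one cell
corpus_median = np.array([850.0, 3.0, 2.5, 60.0])   # hypothetical medians across the corpus

# Normalize this cell's expression by the corpus-wide median, then rank descending.
normalized = cell_counts / corpus_median
ranked_genes = genes[np.argsort(-normalized)]
print(ranked_genes)  # ['TF_A' 'TF_B' 'HOUSEKEEPING_GENE' 'OTHER_GENE']
```

The ubiquitously expressed housekeeping gene drops in rank despite its large raw count, while the lowly expressed transcription factors that distinguish cell state move toward the front, which is the behavior the README describes.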
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.02,
+   "classifier_dropout": null,
+   "hidden_act": "relu",
+   "hidden_dropout_prob": 0.02,
+   "hidden_size": 256,
+   "initializer_range": 0.02,
+   "intermediate_size": 512,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 2048,
+   "model_type": "bert",
+   "num_attention_heads": 4,
+   "num_hidden_layers": 6,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.32.0",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 15994
+ }
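
For orientation, this configuration matches the 6-layer, 256-dimensional model described in the README (maximum input size 2048, vocabulary of 15994 gene tokens). Below is a minimal sketch of loading it with the Hugging Face transformers API, assuming the checkpoint at the repository root corresponds to this config:

```python
# Minimal sketch: load the pretrained masked-language model defined by the config above.
# "BertForMaskedLM" matches the "architectures" field; the fine-tuning notebook below
# instead initializes a task head (BertForSequenceClassification) from the same weights.
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("ctheodoris/Geneformer")
print(model.config.num_hidden_layers, model.config.hidden_size)  # expect: 6 256
```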
examples/cell_classification.ipynb ADDED
@@ -0,0 +1,1952 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "234afff3",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Geneformer Fine-Tuning for Cell Annotation Application"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 2,
14
+ "id": "1cbe6178-ea4d-478a-80a8-65ffaa4c1820",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "GPU_NUMBER = [0]\n",
20
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n",
21
+ "os.environ[\"NCCL_DEBUG\"] = \"INFO\""
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "id": "a9885d9f-00ac-4c84-b6a3-b7b648a90f0f",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "# imports\n",
32
+ "from collections import Counter\n",
33
+ "import datetime\n",
34
+ "import pickle\n",
35
+ "import subprocess\n",
36
+ "import seaborn as sns; sns.set()\n",
37
+ "from datasets import load_from_disk\n",
38
+ "from sklearn.metrics import accuracy_score, f1_score\n",
39
+ "from transformers import BertForSequenceClassification\n",
40
+ "from transformers import Trainer\n",
41
+ "from transformers.training_args import TrainingArguments\n",
42
+ "\n",
43
+ "from geneformer import DataCollatorForCellClassification"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "68bd3b98-5409-4105-b7af-f1ff64ea6a72",
49
+ "metadata": {},
50
+ "source": [
51
+ "## Prepare training and evaluation datasets"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 15,
57
+ "id": "5735f1b7-7595-4a02-be17-2c5b970ad81a",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "# load cell type dataset (includes all tissues)\n",
62
+ "train_dataset=load_from_disk(\"/path/to/cell_type_train_data.dataset\")"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "a4297a02-4c4c-434c-ae55-3387a0b239b5",
69
+ "metadata": {
70
+ "collapsed": true,
71
+ "jupyter": {
72
+ "outputs_hidden": true
73
+ },
74
+ "tags": []
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "dataset_list = []\n",
79
+ "evalset_list = []\n",
80
+ "organ_list = []\n",
81
+ "target_dict_list = []\n",
82
+ "\n",
83
+ "for organ in Counter(train_dataset[\"organ_major\"]).keys():\n",
84
+ " # collect list of tissues for fine-tuning (immune and bone marrow are included together)\n",
85
+ " if organ in [\"bone_marrow\"]: \n",
86
+ " continue\n",
87
+ " elif organ==\"immune\":\n",
88
+ " organ_ids = [\"immune\",\"bone_marrow\"]\n",
89
+ " organ_list += [\"immune\"]\n",
90
+ " else:\n",
91
+ " organ_ids = [organ]\n",
92
+ " organ_list += [organ]\n",
93
+ " \n",
94
+ " print(organ)\n",
95
+ " \n",
96
+ " # filter datasets for given organ\n",
97
+ " def if_organ(example):\n",
98
+ " return example[\"organ_major\"] in organ_ids\n",
99
+ " trainset_organ = train_dataset.filter(if_organ, num_proc=16)\n",
100
+ " \n",
101
+ " # per scDeepsort published method, drop cell types representing <0.5% of cells\n",
102
+ " celltype_counter = Counter(trainset_organ[\"cell_type\"])\n",
103
+ " total_cells = sum(celltype_counter.values())\n",
104
+ " cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]\n",
105
+ " def if_not_rare_celltype(example):\n",
106
+ " return example[\"cell_type\"] in cells_to_keep\n",
107
+ " trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)\n",
108
+ " \n",
109
+ " # shuffle datasets and rename columns\n",
110
+ " trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)\n",
111
+ " trainset_organ_shuffled = trainset_organ_shuffled.rename_column(\"cell_type\",\"label\")\n",
112
+ " trainset_organ_shuffled = trainset_organ_shuffled.remove_columns(\"organ_major\")\n",
113
+ " \n",
114
+ " # create dictionary of cell types : label ids\n",
115
+ " target_names = list(Counter(trainset_organ_shuffled[\"label\"]).keys())\n",
116
+ " target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))\n",
117
+ " target_dict_list += [target_name_id_dict]\n",
118
+ " \n",
119
+ " # change labels to numerical ids\n",
120
+ " def classes_to_ids(example):\n",
121
+ " example[\"label\"] = target_name_id_dict[example[\"label\"]]\n",
122
+ " return example\n",
123
+ " labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)\n",
124
+ " \n",
125
+ " # create 80/20 train/eval splits\n",
126
+ " labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])\n",
127
+ " labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])\n",
128
+ " \n",
129
+ " # filter dataset for cell types in corresponding training set\n",
130
+ " trained_labels = list(Counter(labeled_train_split[\"label\"]).keys())\n",
131
+ " def if_trained_label(example):\n",
132
+ " return example[\"label\"] in trained_labels\n",
133
+ " labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)\n",
134
+ "\n",
135
+ " dataset_list += [labeled_train_split]\n",
136
+ " evalset_list += [labeled_eval_split_subset]"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 20,
142
+ "id": "83e20521-597a-4c54-897b-c4d42ea622c2",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "trainset_dict = dict(zip(organ_list,dataset_list))\n",
147
+ "traintargetdict_dict = dict(zip(organ_list,target_dict_list))\n",
148
+ "\n",
149
+ "evalset_dict = dict(zip(organ_list,evalset_list))"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "id": "10eb110d-ba43-4efc-bc43-1815d6912647",
155
+ "metadata": {},
156
+ "source": [
157
+ "## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 18,
163
+ "id": "cd7b1cfb-f5cb-460e-ae77-769522ece054",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "def compute_metrics(pred):\n",
168
+ " labels = pred.label_ids\n",
169
+ " preds = pred.predictions.argmax(-1)\n",
170
+ " # calculate accuracy and macro f1 using sklearn's function\n",
171
+ " acc = accuracy_score(labels, preds)\n",
172
+ " macro_f1 = f1_score(labels, preds, average='macro')\n",
173
+ " return {\n",
174
+ " 'accuracy': acc,\n",
175
+ " 'macro_f1': macro_f1\n",
176
+ " }"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "id": "beaab7a4-cc13-4e8f-b137-ed18ff7b633c",
182
+ "metadata": {},
183
+ "source": [
184
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 19,
190
+ "id": "d24e1ab7-0131-44bd-b458-1ce5ba31853e",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "# set model parameters\n",
195
+ "# max input size\n",
196
+ "max_input_size = 2 ** 11 # 2048\n",
197
+ "\n",
198
+ "# set training hyperparameters\n",
199
+ "# max learning rate\n",
200
+ "max_lr = 5e-5\n",
201
+ "# how many pretrained layers to freeze\n",
202
+ "freeze_layers = 0\n",
203
+ "# number gpus\n",
204
+ "num_gpus = 1\n",
205
+ "# number cpu cores\n",
206
+ "num_proc = 16\n",
207
+ "# batch size for training and eval\n",
208
+ "geneformer_batch_size = 12\n",
209
+ "# learning schedule\n",
210
+ "lr_schedule_fn = \"linear\"\n",
211
+ "# warmup steps\n",
212
+ "warmup_steps = 500\n",
213
+ "# number of epochs\n",
214
+ "epochs = 10\n",
215
+ "# optimizer\n",
216
+ "optimizer = \"adamw\""
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 20,
222
+ "id": "05164c24-5fbf-4372-b26c-a43f3777a88d",
223
+ "metadata": {},
224
+ "outputs": [
225
+ {
226
+ "name": "stderr",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
230
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
231
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
232
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
233
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
234
+ ]
235
+ },
236
+ {
237
+ "name": "stdout",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "spleen\n"
241
+ ]
242
+ },
243
+ {
244
+ "name": "stderr",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
248
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
249
+ ]
250
+ },
251
+ {
252
+ "data": {
253
+ "text/html": [
254
+ "\n",
255
+ " <div>\n",
256
+ " \n",
257
+ " <progress value='10280' max='10280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
258
+ " [10280/10280 13:33, Epoch 10/10]\n",
259
+ " </div>\n",
260
+ " <table border=\"1\" class=\"dataframe\">\n",
261
+ " <thead>\n",
262
+ " <tr style=\"text-align: left;\">\n",
263
+ " <th>Epoch</th>\n",
264
+ " <th>Training Loss</th>\n",
265
+ " <th>Validation Loss</th>\n",
266
+ " <th>Accuracy</th>\n",
267
+ " <th>Macro F1</th>\n",
268
+ " <th>Weighted F1</th>\n",
269
+ " </tr>\n",
270
+ " </thead>\n",
271
+ " <tbody>\n",
272
+ " <tr>\n",
273
+ " <td>1</td>\n",
274
+ " <td>0.087000</td>\n",
275
+ " <td>0.068067</td>\n",
276
+ " <td>0.985404</td>\n",
277
+ " <td>0.956839</td>\n",
278
+ " <td>0.985483</td>\n",
279
+ " </tr>\n",
280
+ " <tr>\n",
281
+ " <td>2</td>\n",
282
+ " <td>0.044400</td>\n",
283
+ " <td>0.075289</td>\n",
284
+ " <td>0.985079</td>\n",
285
+ " <td>0.955069</td>\n",
286
+ " <td>0.984898</td>\n",
287
+ " </tr>\n",
288
+ " <tr>\n",
289
+ " <td>3</td>\n",
290
+ " <td>0.066700</td>\n",
291
+ " <td>0.078703</td>\n",
292
+ " <td>0.983782</td>\n",
293
+ " <td>0.953240</td>\n",
294
+ " <td>0.983959</td>\n",
295
+ " </tr>\n",
296
+ " <tr>\n",
297
+ " <td>4</td>\n",
298
+ " <td>0.037400</td>\n",
299
+ " <td>0.057132</td>\n",
300
+ " <td>0.989945</td>\n",
301
+ " <td>0.970619</td>\n",
302
+ " <td>0.989883</td>\n",
303
+ " </tr>\n",
304
+ " <tr>\n",
305
+ " <td>5</td>\n",
306
+ " <td>0.025000</td>\n",
307
+ " <td>0.061644</td>\n",
308
+ " <td>0.988323</td>\n",
309
+ " <td>0.961126</td>\n",
310
+ " <td>0.988211</td>\n",
311
+ " </tr>\n",
312
+ " <tr>\n",
313
+ " <td>6</td>\n",
314
+ " <td>0.022400</td>\n",
315
+ " <td>0.065323</td>\n",
316
+ " <td>0.989296</td>\n",
317
+ " <td>0.969737</td>\n",
318
+ " <td>0.989362</td>\n",
319
+ " </tr>\n",
320
+ " <tr>\n",
321
+ " <td>7</td>\n",
322
+ " <td>0.018600</td>\n",
323
+ " <td>0.063710</td>\n",
324
+ " <td>0.989620</td>\n",
325
+ " <td>0.969436</td>\n",
326
+ " <td>0.989579</td>\n",
327
+ " </tr>\n",
328
+ " <tr>\n",
329
+ " <td>8</td>\n",
330
+ " <td>0.039800</td>\n",
331
+ " <td>0.065919</td>\n",
332
+ " <td>0.989945</td>\n",
333
+ " <td>0.968065</td>\n",
334
+ " <td>0.989802</td>\n",
335
+ " </tr>\n",
336
+ " <tr>\n",
337
+ " <td>9</td>\n",
338
+ " <td>0.030200</td>\n",
339
+ " <td>0.061359</td>\n",
340
+ " <td>0.990269</td>\n",
341
+ " <td>0.971700</td>\n",
342
+ " <td>0.990314</td>\n",
343
+ " </tr>\n",
344
+ " <tr>\n",
345
+ " <td>10</td>\n",
346
+ " <td>0.013400</td>\n",
347
+ " <td>0.059181</td>\n",
348
+ " <td>0.991567</td>\n",
349
+ " <td>0.974599</td>\n",
350
+ " <td>0.991552</td>\n",
351
+ " </tr>\n",
352
+ " </tbody>\n",
353
+ "</table><p>"
354
+ ],
355
+ "text/plain": [
356
+ "<IPython.core.display.HTML object>"
357
+ ]
358
+ },
359
+ "metadata": {},
360
+ "output_type": "display_data"
361
+ },
362
+ {
363
+ "name": "stderr",
364
+ "output_type": "stream",
365
+ "text": [
366
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
367
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
368
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
369
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
370
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
371
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
372
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
373
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
374
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
375
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
376
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
377
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
378
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
379
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
380
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
381
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
382
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
383
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
384
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
385
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
386
+ ]
387
+ },
388
+ {
389
+ "data": {
390
+ "text/html": [
391
+ "\n",
392
+ " <div>\n",
393
+ " \n",
394
+ " <progress value='257' max='257' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
395
+ " [257/257 00:07]\n",
396
+ " </div>\n",
397
+ " "
398
+ ],
399
+ "text/plain": [
400
+ "<IPython.core.display.HTML object>"
401
+ ]
402
+ },
403
+ "metadata": {},
404
+ "output_type": "display_data"
405
+ },
406
+ {
407
+ "name": "stderr",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
411
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
412
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
413
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
414
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
415
+ ]
416
+ },
417
+ {
418
+ "name": "stdout",
419
+ "output_type": "stream",
420
+ "text": [
421
+ "kidney\n"
422
+ ]
423
+ },
424
+ {
425
+ "name": "stderr",
426
+ "output_type": "stream",
427
+ "text": [
428
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
429
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
430
+ ]
431
+ },
432
+ {
433
+ "data": {
434
+ "text/html": [
435
+ "\n",
436
+ " <div>\n",
437
+ " \n",
438
+ " <progress value='29340' max='29340' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
439
+ " [29340/29340 45:43, Epoch 10/10]\n",
440
+ " </div>\n",
441
+ " <table border=\"1\" class=\"dataframe\">\n",
442
+ " <thead>\n",
443
+ " <tr style=\"text-align: left;\">\n",
444
+ " <th>Epoch</th>\n",
445
+ " <th>Training Loss</th>\n",
446
+ " <th>Validation Loss</th>\n",
447
+ " <th>Accuracy</th>\n",
448
+ " <th>Macro F1</th>\n",
449
+ " <th>Weighted F1</th>\n",
450
+ " </tr>\n",
451
+ " </thead>\n",
452
+ " <tbody>\n",
453
+ " <tr>\n",
454
+ " <td>1</td>\n",
455
+ " <td>0.326900</td>\n",
456
+ " <td>0.299193</td>\n",
457
+ " <td>0.912500</td>\n",
458
+ " <td>0.823067</td>\n",
459
+ " <td>0.909627</td>\n",
460
+ " </tr>\n",
461
+ " <tr>\n",
462
+ " <td>2</td>\n",
463
+ " <td>0.224200</td>\n",
464
+ " <td>0.239580</td>\n",
465
+ " <td>0.926477</td>\n",
466
+ " <td>0.850237</td>\n",
467
+ " <td>0.923902</td>\n",
468
+ " </tr>\n",
469
+ " <tr>\n",
470
+ " <td>3</td>\n",
471
+ " <td>0.221600</td>\n",
472
+ " <td>0.242810</td>\n",
473
+ " <td>0.930227</td>\n",
474
+ " <td>0.878553</td>\n",
475
+ " <td>0.930349</td>\n",
476
+ " </tr>\n",
477
+ " <tr>\n",
478
+ " <td>4</td>\n",
479
+ " <td>0.166100</td>\n",
480
+ " <td>0.264178</td>\n",
481
+ " <td>0.933409</td>\n",
482
+ " <td>0.884759</td>\n",
483
+ " <td>0.933031</td>\n",
484
+ " </tr>\n",
485
+ " <tr>\n",
486
+ " <td>5</td>\n",
487
+ " <td>0.144100</td>\n",
488
+ " <td>0.279282</td>\n",
489
+ " <td>0.935000</td>\n",
490
+ " <td>0.887659</td>\n",
491
+ " <td>0.934987</td>\n",
492
+ " </tr>\n",
493
+ " <tr>\n",
494
+ " <td>6</td>\n",
495
+ " <td>0.112800</td>\n",
496
+ " <td>0.307647</td>\n",
497
+ " <td>0.935909</td>\n",
498
+ " <td>0.889239</td>\n",
499
+ " <td>0.935365</td>\n",
500
+ " </tr>\n",
501
+ " <tr>\n",
502
+ " <td>7</td>\n",
503
+ " <td>0.084600</td>\n",
504
+ " <td>0.326399</td>\n",
505
+ " <td>0.932841</td>\n",
506
+ " <td>0.892447</td>\n",
507
+ " <td>0.933191</td>\n",
508
+ " </tr>\n",
509
+ " <tr>\n",
510
+ " <td>8</td>\n",
511
+ " <td>0.068300</td>\n",
512
+ " <td>0.332626</td>\n",
513
+ " <td>0.936591</td>\n",
514
+ " <td>0.891629</td>\n",
515
+ " <td>0.936354</td>\n",
516
+ " </tr>\n",
517
+ " <tr>\n",
518
+ " <td>9</td>\n",
519
+ " <td>0.065500</td>\n",
520
+ " <td>0.348174</td>\n",
521
+ " <td>0.935227</td>\n",
522
+ " <td>0.889484</td>\n",
523
+ " <td>0.935040</td>\n",
524
+ " </tr>\n",
525
+ " <tr>\n",
526
+ " <td>10</td>\n",
527
+ " <td>0.046100</td>\n",
528
+ " <td>0.355350</td>\n",
529
+ " <td>0.935000</td>\n",
530
+ " <td>0.894578</td>\n",
531
+ " <td>0.934971</td>\n",
532
+ " </tr>\n",
533
+ " </tbody>\n",
534
+ "</table><p>"
535
+ ],
536
+ "text/plain": [
537
+ "<IPython.core.display.HTML object>"
538
+ ]
539
+ },
540
+ "metadata": {},
541
+ "output_type": "display_data"
542
+ },
543
+ {
544
+ "name": "stderr",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
548
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
549
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
550
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
551
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
552
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
553
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
554
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
555
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
556
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
557
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
558
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
559
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
560
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
561
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
562
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
563
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
564
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
565
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
566
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
567
+ ]
568
+ },
569
+ {
570
+ "data": {
571
+ "text/html": [
572
+ "\n",
573
+ " <div>\n",
574
+ " \n",
575
+ " <progress value='734' max='734' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
576
+ " [734/734 00:27]\n",
577
+ " </div>\n",
578
+ " "
579
+ ],
580
+ "text/plain": [
581
+ "<IPython.core.display.HTML object>"
582
+ ]
583
+ },
584
+ "metadata": {},
585
+ "output_type": "display_data"
586
+ },
587
+ {
588
+ "name": "stderr",
589
+ "output_type": "stream",
590
+ "text": [
591
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
592
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
593
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
594
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
595
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
596
+ ]
597
+ },
598
+ {
599
+ "name": "stdout",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "lung\n"
603
+ ]
604
+ },
605
+ {
606
+ "name": "stderr",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
610
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
611
+ ]
612
+ },
613
+ {
614
+ "data": {
615
+ "text/html": [
616
+ "\n",
617
+ " <div>\n",
618
+ " \n",
619
+ " <progress value='21750' max='21750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
620
+ " [21750/21750 30:32, Epoch 10/10]\n",
621
+ " </div>\n",
622
+ " <table border=\"1\" class=\"dataframe\">\n",
623
+ " <thead>\n",
624
+ " <tr style=\"text-align: left;\">\n",
625
+ " <th>Epoch</th>\n",
626
+ " <th>Training Loss</th>\n",
627
+ " <th>Validation Loss</th>\n",
628
+ " <th>Accuracy</th>\n",
629
+ " <th>Macro F1</th>\n",
630
+ " <th>Weighted F1</th>\n",
631
+ " </tr>\n",
632
+ " </thead>\n",
633
+ " <tbody>\n",
634
+ " <tr>\n",
635
+ " <td>1</td>\n",
636
+ " <td>0.337600</td>\n",
637
+ " <td>0.341523</td>\n",
638
+ " <td>0.906360</td>\n",
639
+ " <td>0.759979</td>\n",
640
+ " <td>0.899310</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <td>2</td>\n",
644
+ " <td>0.211900</td>\n",
645
+ " <td>0.258954</td>\n",
646
+ " <td>0.928429</td>\n",
647
+ " <td>0.835534</td>\n",
648
+ " <td>0.925903</td>\n",
649
+ " </tr>\n",
650
+ " <tr>\n",
651
+ " <td>3</td>\n",
652
+ " <td>0.208600</td>\n",
653
+ " <td>0.282081</td>\n",
654
+ " <td>0.930421</td>\n",
655
+ " <td>0.842786</td>\n",
656
+ " <td>0.928013</td>\n",
657
+ " </tr>\n",
658
+ " <tr>\n",
659
+ " <td>4</td>\n",
660
+ " <td>0.144400</td>\n",
661
+ " <td>0.253047</td>\n",
662
+ " <td>0.935479</td>\n",
663
+ " <td>0.871712</td>\n",
664
+ " <td>0.935234</td>\n",
665
+ " </tr>\n",
666
+ " <tr>\n",
667
+ " <td>5</td>\n",
668
+ " <td>0.109200</td>\n",
669
+ " <td>0.268833</td>\n",
670
+ " <td>0.939464</td>\n",
671
+ " <td>0.876173</td>\n",
672
+ " <td>0.938870</td>\n",
673
+ " </tr>\n",
674
+ " <tr>\n",
675
+ " <td>6</td>\n",
676
+ " <td>0.132700</td>\n",
677
+ " <td>0.282697</td>\n",
678
+ " <td>0.940536</td>\n",
679
+ " <td>0.883271</td>\n",
680
+ " <td>0.940191</td>\n",
681
+ " </tr>\n",
682
+ " <tr>\n",
683
+ " <td>7</td>\n",
684
+ " <td>0.081800</td>\n",
685
+ " <td>0.295864</td>\n",
686
+ " <td>0.940843</td>\n",
687
+ " <td>0.884201</td>\n",
688
+ " <td>0.940170</td>\n",
689
+ " </tr>\n",
690
+ " <tr>\n",
691
+ " <td>8</td>\n",
692
+ " <td>0.035900</td>\n",
693
+ " <td>0.306600</td>\n",
694
+ " <td>0.941916</td>\n",
695
+ " <td>0.884777</td>\n",
696
+ " <td>0.941578</td>\n",
697
+ " </tr>\n",
698
+ " <tr>\n",
699
+ " <td>9</td>\n",
700
+ " <td>0.050800</td>\n",
701
+ " <td>0.311677</td>\n",
702
+ " <td>0.940536</td>\n",
703
+ " <td>0.883437</td>\n",
704
+ " <td>0.940294</td>\n",
705
+ " </tr>\n",
706
+ " <tr>\n",
707
+ " <td>10</td>\n",
708
+ " <td>0.035800</td>\n",
709
+ " <td>0.315360</td>\n",
710
+ " <td>0.940843</td>\n",
711
+ " <td>0.883551</td>\n",
712
+ " <td>0.940612</td>\n",
713
+ " </tr>\n",
714
+ " </tbody>\n",
715
+ "</table><p>"
716
+ ],
717
+ "text/plain": [
718
+ "<IPython.core.display.HTML object>"
719
+ ]
720
+ },
721
+ "metadata": {},
722
+ "output_type": "display_data"
723
+ },
724
+ {
725
+ "name": "stderr",
726
+ "output_type": "stream",
727
+ "text": [
728
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
729
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
730
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
731
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
732
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
733
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
734
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
735
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
736
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
737
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
738
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
739
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
740
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
741
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
742
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
743
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
744
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
745
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
746
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
747
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
748
+ ]
749
+ },
750
+ {
751
+ "data": {
752
+ "text/html": [
753
+ "\n",
754
+ " <div>\n",
755
+ " \n",
756
+ " <progress value='544' max='544' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
757
+ " [544/544 00:19]\n",
758
+ " </div>\n",
759
+ " "
760
+ ],
761
+ "text/plain": [
762
+ "<IPython.core.display.HTML object>"
763
+ ]
764
+ },
765
+ "metadata": {},
766
+ "output_type": "display_data"
767
+ },
768
+ {
769
+ "name": "stderr",
770
+ "output_type": "stream",
771
+ "text": [
772
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
773
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
774
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
775
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
776
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
777
+ ]
778
+ },
779
+ {
780
+ "name": "stdout",
781
+ "output_type": "stream",
782
+ "text": [
783
+ "brain\n"
784
+ ]
785
+ },
786
+ {
787
+ "name": "stderr",
788
+ "output_type": "stream",
789
+ "text": [
790
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
791
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
792
+ ]
793
+ },
794
+ {
795
+ "data": {
796
+ "text/html": [
797
+ "\n",
798
+ " <div>\n",
799
+ " \n",
800
+ " <progress value='8880' max='8880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
801
+ " [8880/8880 11:14, Epoch 10/10]\n",
802
+ " </div>\n",
803
+ " <table border=\"1\" class=\"dataframe\">\n",
804
+ " <thead>\n",
805
+ " <tr style=\"text-align: left;\">\n",
806
+ " <th>Epoch</th>\n",
807
+ " <th>Training Loss</th>\n",
808
+ " <th>Validation Loss</th>\n",
809
+ " <th>Accuracy</th>\n",
810
+ " <th>Macro F1</th>\n",
811
+ " <th>Weighted F1</th>\n",
812
+ " </tr>\n",
813
+ " </thead>\n",
814
+ " <tbody>\n",
815
+ " <tr>\n",
816
+ " <td>1</td>\n",
817
+ " <td>0.163100</td>\n",
818
+ " <td>0.156640</td>\n",
819
+ " <td>0.970345</td>\n",
820
+ " <td>0.736455</td>\n",
821
+ " <td>0.960714</td>\n",
822
+ " </tr>\n",
823
+ " <tr>\n",
824
+ " <td>2</td>\n",
825
+ " <td>0.149800</td>\n",
826
+ " <td>0.134897</td>\n",
827
+ " <td>0.968844</td>\n",
828
+ " <td>0.747114</td>\n",
829
+ " <td>0.960726</td>\n",
830
+ " </tr>\n",
831
+ " <tr>\n",
832
+ " <td>3</td>\n",
833
+ " <td>0.105600</td>\n",
834
+ " <td>0.115354</td>\n",
835
+ " <td>0.972222</td>\n",
836
+ " <td>0.775271</td>\n",
837
+ " <td>0.964932</td>\n",
838
+ " </tr>\n",
839
+ " <tr>\n",
840
+ " <td>4</td>\n",
841
+ " <td>0.086900</td>\n",
842
+ " <td>0.207918</td>\n",
843
+ " <td>0.968844</td>\n",
844
+ " <td>0.707927</td>\n",
845
+ " <td>0.958257</td>\n",
846
+ " </tr>\n",
847
+ " <tr>\n",
848
+ " <td>5</td>\n",
849
+ " <td>0.056400</td>\n",
850
+ " <td>0.106548</td>\n",
851
+ " <td>0.974099</td>\n",
852
+ " <td>0.839838</td>\n",
853
+ " <td>0.971611</td>\n",
854
+ " </tr>\n",
855
+ " <tr>\n",
856
+ " <td>6</td>\n",
857
+ " <td>0.037600</td>\n",
858
+ " <td>0.117437</td>\n",
859
+ " <td>0.978228</td>\n",
860
+ " <td>0.856578</td>\n",
861
+ " <td>0.975665</td>\n",
862
+ " </tr>\n",
863
+ " <tr>\n",
864
+ " <td>7</td>\n",
865
+ " <td>0.030500</td>\n",
866
+ " <td>0.127885</td>\n",
867
+ " <td>0.974474</td>\n",
868
+ " <td>0.856296</td>\n",
869
+ " <td>0.973531</td>\n",
870
+ " </tr>\n",
871
+ " <tr>\n",
872
+ " <td>8</td>\n",
873
+ " <td>0.019300</td>\n",
874
+ " <td>0.143203</td>\n",
875
+ " <td>0.977853</td>\n",
876
+ " <td>0.859362</td>\n",
877
+ " <td>0.975776</td>\n",
878
+ " </tr>\n",
879
+ " <tr>\n",
880
+ " <td>9</td>\n",
881
+ " <td>0.007400</td>\n",
882
+ " <td>0.153758</td>\n",
883
+ " <td>0.972598</td>\n",
884
+ " <td>0.852835</td>\n",
885
+ " <td>0.972314</td>\n",
886
+ " </tr>\n",
887
+ " <tr>\n",
888
+ " <td>10</td>\n",
889
+ " <td>0.017200</td>\n",
890
+ " <td>0.153911</td>\n",
891
+ " <td>0.975976</td>\n",
892
+ " <td>0.858196</td>\n",
893
+ " <td>0.974498</td>\n",
894
+ " </tr>\n",
895
+ " </tbody>\n",
896
+ "</table><p>"
897
+ ],
898
+ "text/plain": [
899
+ "<IPython.core.display.HTML object>"
900
+ ]
901
+ },
902
+ "metadata": {},
903
+ "output_type": "display_data"
904
+ },
905
+ {
906
+ "name": "stderr",
907
+ "output_type": "stream",
908
+ "text": [
909
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
910
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
911
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
912
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
913
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
914
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
915
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
916
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
917
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
918
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
919
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
920
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
921
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
922
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
923
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
924
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
925
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
926
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
927
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
928
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
929
+ ]
930
+ },
931
+ {
932
+ "data": {
933
+ "text/html": [
934
+ "\n",
935
+ " <div>\n",
936
+ " \n",
937
+ " <progress value='222' max='222' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
938
+ " [222/222 00:04]\n",
939
+ " </div>\n",
940
+ " "
941
+ ],
942
+ "text/plain": [
943
+ "<IPython.core.display.HTML object>"
944
+ ]
945
+ },
946
+ "metadata": {},
947
+ "output_type": "display_data"
948
+ },
949
+ {
950
+ "name": "stderr",
951
+ "output_type": "stream",
952
+ "text": [
953
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
954
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
955
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
956
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
957
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
958
+ ]
959
+ },
960
+ {
961
+ "name": "stdout",
962
+ "output_type": "stream",
963
+ "text": [
964
+ "placenta\n"
965
+ ]
966
+ },
967
+ {
968
+ "name": "stderr",
969
+ "output_type": "stream",
970
+ "text": [
971
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
972
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
973
+ ]
974
+ },
975
+ {
976
+ "data": {
977
+ "text/html": [
978
+ "\n",
979
+ " <div>\n",
980
+ " \n",
981
+ " <progress value='6180' max='6180' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
982
+ " [6180/6180 10:28, Epoch 10/10]\n",
983
+ " </div>\n",
984
+ " <table border=\"1\" class=\"dataframe\">\n",
985
+ " <thead>\n",
986
+ " <tr style=\"text-align: left;\">\n",
987
+ " <th>Epoch</th>\n",
988
+ " <th>Training Loss</th>\n",
989
+ " <th>Validation Loss</th>\n",
990
+ " <th>Accuracy</th>\n",
991
+ " <th>Macro F1</th>\n",
992
+ " <th>Weighted F1</th>\n",
993
+ " </tr>\n",
994
+ " </thead>\n",
995
+ " <tbody>\n",
996
+ " <tr>\n",
997
+ " <td>1</td>\n",
998
+ " <td>0.128700</td>\n",
999
+ " <td>0.125175</td>\n",
1000
+ " <td>0.960626</td>\n",
1001
+ " <td>0.935752</td>\n",
1002
+ " <td>0.959463</td>\n",
1003
+ " </tr>\n",
1004
+ " <tr>\n",
1005
+ " <td>2</td>\n",
1006
+ " <td>0.064000</td>\n",
1007
+ " <td>0.215607</td>\n",
1008
+ " <td>0.951456</td>\n",
1009
+ " <td>0.920579</td>\n",
1010
+ " <td>0.949828</td>\n",
1011
+ " </tr>\n",
1012
+ " <tr>\n",
1013
+ " <td>3</td>\n",
1014
+ " <td>0.051300</td>\n",
1015
+ " <td>0.203044</td>\n",
1016
+ " <td>0.961165</td>\n",
1017
+ " <td>0.934195</td>\n",
1018
+ " <td>0.959470</td>\n",
1019
+ " </tr>\n",
1020
+ " <tr>\n",
1021
+ " <td>4</td>\n",
1022
+ " <td>0.045300</td>\n",
1023
+ " <td>0.115701</td>\n",
1024
+ " <td>0.978964</td>\n",
1025
+ " <td>0.966387</td>\n",
1026
+ " <td>0.978788</td>\n",
1027
+ " </tr>\n",
1028
+ " <tr>\n",
1029
+ " <td>5</td>\n",
1030
+ " <td>0.048200</td>\n",
1031
+ " <td>0.149484</td>\n",
1032
+ " <td>0.973571</td>\n",
1033
+ " <td>0.958927</td>\n",
1034
+ " <td>0.973305</td>\n",
1035
+ " </tr>\n",
1036
+ " <tr>\n",
1037
+ " <td>6</td>\n",
1038
+ " <td>0.040900</td>\n",
1039
+ " <td>0.134339</td>\n",
1040
+ " <td>0.978964</td>\n",
1041
+ " <td>0.967466</td>\n",
1042
+ " <td>0.978899</td>\n",
1043
+ " </tr>\n",
1044
+ " <tr>\n",
1045
+ " <td>7</td>\n",
1046
+ " <td>0.001600</td>\n",
1047
+ " <td>0.159900</td>\n",
1048
+ " <td>0.978425</td>\n",
1049
+ " <td>0.966713</td>\n",
1050
+ " <td>0.978211</td>\n",
1051
+ " </tr>\n",
1052
+ " <tr>\n",
1053
+ " <td>8</td>\n",
1054
+ " <td>0.002400</td>\n",
1055
+ " <td>0.125351</td>\n",
1056
+ " <td>0.979504</td>\n",
1057
+ " <td>0.968064</td>\n",
1058
+ " <td>0.979428</td>\n",
1059
+ " </tr>\n",
1060
+ " <tr>\n",
1061
+ " <td>9</td>\n",
1062
+ " <td>0.009400</td>\n",
1063
+ " <td>0.120132</td>\n",
1064
+ " <td>0.980583</td>\n",
1065
+ " <td>0.969631</td>\n",
1066
+ " <td>0.980506</td>\n",
1067
+ " </tr>\n",
1068
+ " <tr>\n",
1069
+ " <td>10</td>\n",
1070
+ " <td>0.001500</td>\n",
1071
+ " <td>0.137864</td>\n",
1072
+ " <td>0.978964</td>\n",
1073
+ " <td>0.967180</td>\n",
1074
+ " <td>0.978825</td>\n",
1075
+ " </tr>\n",
1076
+ " </tbody>\n",
1077
+ "</table><p>"
1078
+ ],
1079
+ "text/plain": [
1080
+ "<IPython.core.display.HTML object>"
1081
+ ]
1082
+ },
1083
+ "metadata": {},
1084
+ "output_type": "display_data"
1085
+ },
1086
+ {
1087
+ "name": "stderr",
1088
+ "output_type": "stream",
1089
+ "text": [
1090
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1091
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1092
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1093
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1094
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1095
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1096
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1097
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1098
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1099
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1100
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1101
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1102
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1103
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1104
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1105
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1106
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1107
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1108
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1109
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1110
+ ]
1111
+ },
1112
+ {
1113
+ "data": {
1114
+ "text/html": [
1115
+ "\n",
1116
+ " <div>\n",
1117
+ " \n",
1118
+ " <progress value='155' max='155' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1119
+ " [155/155 00:05]\n",
1120
+ " </div>\n",
1121
+ " "
1122
+ ],
1123
+ "text/plain": [
1124
+ "<IPython.core.display.HTML object>"
1125
+ ]
1126
+ },
1127
+ "metadata": {},
1128
+ "output_type": "display_data"
1129
+ },
1130
+ {
1131
+ "name": "stderr",
1132
+ "output_type": "stream",
1133
+ "text": [
1134
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1135
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1136
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1137
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1138
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1139
+ ]
1140
+ },
1141
+ {
1142
+ "name": "stdout",
1143
+ "output_type": "stream",
1144
+ "text": [
1145
+ "immune\n"
1146
+ ]
1147
+ },
1148
+ {
1149
+ "name": "stderr",
1150
+ "output_type": "stream",
1151
+ "text": [
1152
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1153
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1154
+ ]
1155
+ },
1156
+ {
1157
+ "data": {
1158
+ "text/html": [
1159
+ "\n",
1160
+ " <div>\n",
1161
+ " \n",
1162
+ " <progress value='17140' max='17140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1163
+ " [17140/17140 22:02, Epoch 10/10]\n",
1164
+ " </div>\n",
1165
+ " <table border=\"1\" class=\"dataframe\">\n",
1166
+ " <thead>\n",
1167
+ " <tr style=\"text-align: left;\">\n",
1168
+ " <th>Epoch</th>\n",
1169
+ " <th>Training Loss</th>\n",
1170
+ " <th>Validation Loss</th>\n",
1171
+ " <th>Accuracy</th>\n",
1172
+ " <th>Macro F1</th>\n",
1173
+ " <th>Weighted F1</th>\n",
1174
+ " </tr>\n",
1175
+ " </thead>\n",
1176
+ " <tbody>\n",
1177
+ " <tr>\n",
1178
+ " <td>1</td>\n",
1179
+ " <td>0.288900</td>\n",
1180
+ " <td>0.231582</td>\n",
1181
+ " <td>0.936770</td>\n",
1182
+ " <td>0.868405</td>\n",
1183
+ " <td>0.934816</td>\n",
1184
+ " </tr>\n",
1185
+ " <tr>\n",
1186
+ " <td>2</td>\n",
1187
+ " <td>0.203200</td>\n",
1188
+ " <td>0.206292</td>\n",
1189
+ " <td>0.937354</td>\n",
1190
+ " <td>0.888661</td>\n",
1191
+ " <td>0.939555</td>\n",
1192
+ " </tr>\n",
1193
+ " <tr>\n",
1194
+ " <td>3</td>\n",
1195
+ " <td>0.183500</td>\n",
1196
+ " <td>0.195811</td>\n",
1197
+ " <td>0.944942</td>\n",
1198
+ " <td>0.891149</td>\n",
1199
+ " <td>0.944008</td>\n",
1200
+ " </tr>\n",
1201
+ " <tr>\n",
1202
+ " <td>4</td>\n",
1203
+ " <td>0.151000</td>\n",
1204
+ " <td>0.219581</td>\n",
1205
+ " <td>0.947665</td>\n",
1206
+ " <td>0.906578</td>\n",
1207
+ " <td>0.947093</td>\n",
1208
+ " </tr>\n",
1209
+ " <tr>\n",
1210
+ " <td>5</td>\n",
1211
+ " <td>0.090000</td>\n",
1212
+ " <td>0.247120</td>\n",
1213
+ " <td>0.946693</td>\n",
1214
+ " <td>0.898812</td>\n",
1215
+ " <td>0.945808</td>\n",
1216
+ " </tr>\n",
1217
+ " <tr>\n",
1218
+ " <td>6</td>\n",
1219
+ " <td>0.060400</td>\n",
1220
+ " <td>0.249662</td>\n",
1221
+ " <td>0.948444</td>\n",
1222
+ " <td>0.905014</td>\n",
1223
+ " <td>0.947975</td>\n",
1224
+ " </tr>\n",
1225
+ " <tr>\n",
1226
+ " <td>7</td>\n",
1227
+ " <td>0.071300</td>\n",
1228
+ " <td>0.272767</td>\n",
1229
+ " <td>0.949416</td>\n",
1230
+ " <td>0.911514</td>\n",
1231
+ " <td>0.949748</td>\n",
1232
+ " </tr>\n",
1233
+ " <tr>\n",
1234
+ " <td>8</td>\n",
1235
+ " <td>0.052600</td>\n",
1236
+ " <td>0.305051</td>\n",
1237
+ " <td>0.945331</td>\n",
1238
+ " <td>0.902348</td>\n",
1239
+ " <td>0.944987</td>\n",
1240
+ " </tr>\n",
1241
+ " <tr>\n",
1242
+ " <td>9</td>\n",
1243
+ " <td>0.026900</td>\n",
1244
+ " <td>0.294135</td>\n",
1245
+ " <td>0.948638</td>\n",
1246
+ " <td>0.904058</td>\n",
1247
+ " <td>0.948296</td>\n",
1248
+ " </tr>\n",
1249
+ " <tr>\n",
1250
+ " <td>10</td>\n",
1251
+ " <td>0.034500</td>\n",
1252
+ " <td>0.292029</td>\n",
1253
+ " <td>0.950195</td>\n",
1254
+ " <td>0.908547</td>\n",
1255
+ " <td>0.949753</td>\n",
1256
+ " </tr>\n",
1257
+ " </tbody>\n",
1258
+ "</table><p>"
1259
+ ],
1260
+ "text/plain": [
1261
+ "<IPython.core.display.HTML object>"
1262
+ ]
1263
+ },
1264
+ "metadata": {},
1265
+ "output_type": "display_data"
1266
+ },
1267
+ {
1268
+ "name": "stderr",
1269
+ "output_type": "stream",
1270
+ "text": [
1271
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1272
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1273
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1274
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1275
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1276
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1277
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1278
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1279
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1280
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1281
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1282
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1283
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1284
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1285
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1286
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1287
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1288
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1289
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1290
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1291
+ ]
1292
+ },
1293
+ {
1294
+ "data": {
1295
+ "text/html": [
1296
+ "\n",
1297
+ " <div>\n",
1298
+ " \n",
1299
+ " <progress value='429' max='429' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1300
+ " [429/429 00:13]\n",
1301
+ " </div>\n",
1302
+ " "
1303
+ ],
1304
+ "text/plain": [
1305
+ "<IPython.core.display.HTML object>"
1306
+ ]
1307
+ },
1308
+ "metadata": {},
1309
+ "output_type": "display_data"
1310
+ },
1311
+ {
1312
+ "name": "stderr",
1313
+ "output_type": "stream",
1314
+ "text": [
1315
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1316
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1317
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1318
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1319
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1320
+ ]
1321
+ },
1322
+ {
1323
+ "name": "stdout",
1324
+ "output_type": "stream",
1325
+ "text": [
1326
+ "large_intestine\n"
1327
+ ]
1328
+ },
1329
+ {
1330
+ "name": "stderr",
1331
+ "output_type": "stream",
1332
+ "text": [
1333
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1334
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1335
+ ]
1336
+ },
1337
+ {
1338
+ "data": {
1339
+ "text/html": [
1340
+ "\n",
1341
+ " <div>\n",
1342
+ " \n",
1343
+ " <progress value='33070' max='33070' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1344
+ " [33070/33070 43:02, Epoch 10/10]\n",
1345
+ " </div>\n",
1346
+ " <table border=\"1\" class=\"dataframe\">\n",
1347
+ " <thead>\n",
1348
+ " <tr style=\"text-align: left;\">\n",
1349
+ " <th>Epoch</th>\n",
1350
+ " <th>Training Loss</th>\n",
1351
+ " <th>Validation Loss</th>\n",
1352
+ " <th>Accuracy</th>\n",
1353
+ " <th>Macro F1</th>\n",
1354
+ " <th>Weighted F1</th>\n",
1355
+ " </tr>\n",
1356
+ " </thead>\n",
1357
+ " <tbody>\n",
1358
+ " <tr>\n",
1359
+ " <td>1</td>\n",
1360
+ " <td>0.306200</td>\n",
1361
+ " <td>0.312431</td>\n",
1362
+ " <td>0.908266</td>\n",
1363
+ " <td>0.786242</td>\n",
1364
+ " <td>0.900768</td>\n",
1365
+ " </tr>\n",
1366
+ " <tr>\n",
1367
+ " <td>2</td>\n",
1368
+ " <td>0.223900</td>\n",
1369
+ " <td>0.248096</td>\n",
1370
+ " <td>0.925101</td>\n",
1371
+ " <td>0.841251</td>\n",
1372
+ " <td>0.920987</td>\n",
1373
+ " </tr>\n",
1374
+ " <tr>\n",
1375
+ " <td>3</td>\n",
1376
+ " <td>0.173600</td>\n",
1377
+ " <td>0.259997</td>\n",
1378
+ " <td>0.925907</td>\n",
1379
+ " <td>0.850348</td>\n",
1380
+ " <td>0.926290</td>\n",
1381
+ " </tr>\n",
1382
+ " <tr>\n",
1383
+ " <td>4</td>\n",
1384
+ " <td>0.162900</td>\n",
1385
+ " <td>0.282306</td>\n",
1386
+ " <td>0.925000</td>\n",
1387
+ " <td>0.873669</td>\n",
1388
+ " <td>0.925531</td>\n",
1389
+ " </tr>\n",
1390
+ " <tr>\n",
1391
+ " <td>5</td>\n",
1392
+ " <td>0.143400</td>\n",
1393
+ " <td>0.254494</td>\n",
1394
+ " <td>0.937903</td>\n",
1395
+ " <td>0.876749</td>\n",
1396
+ " <td>0.937836</td>\n",
1397
+ " </tr>\n",
1398
+ " <tr>\n",
1399
+ " <td>6</td>\n",
1400
+ " <td>0.104500</td>\n",
1401
+ " <td>0.289942</td>\n",
1402
+ " <td>0.934677</td>\n",
1403
+ " <td>0.875333</td>\n",
1404
+ " <td>0.934339</td>\n",
1405
+ " </tr>\n",
1406
+ " <tr>\n",
1407
+ " <td>7</td>\n",
1408
+ " <td>0.080300</td>\n",
1409
+ " <td>0.313914</td>\n",
1410
+ " <td>0.935484</td>\n",
1411
+ " <td>0.877271</td>\n",
1412
+ " <td>0.934986</td>\n",
1413
+ " </tr>\n",
1414
+ " <tr>\n",
1415
+ " <td>8</td>\n",
1416
+ " <td>0.063500</td>\n",
1417
+ " <td>0.339868</td>\n",
1418
+ " <td>0.936290</td>\n",
1419
+ " <td>0.882267</td>\n",
1420
+ " <td>0.936187</td>\n",
1421
+ " </tr>\n",
1422
+ " <tr>\n",
1423
+ " <td>9</td>\n",
1424
+ " <td>0.042500</td>\n",
1425
+ " <td>0.345784</td>\n",
1426
+ " <td>0.938911</td>\n",
1427
+ " <td>0.882963</td>\n",
1428
+ " <td>0.938682</td>\n",
1429
+ " </tr>\n",
1430
+ " <tr>\n",
1431
+ " <td>10</td>\n",
1432
+ " <td>0.038900</td>\n",
1433
+ " <td>0.352199</td>\n",
1434
+ " <td>0.939516</td>\n",
1435
+ " <td>0.885509</td>\n",
1436
+ " <td>0.939497</td>\n",
1437
+ " </tr>\n",
1438
+ " </tbody>\n",
1439
+ "</table><p>"
1440
+ ],
1441
+ "text/plain": [
1442
+ "<IPython.core.display.HTML object>"
1443
+ ]
1444
+ },
1445
+ "metadata": {},
1446
+ "output_type": "display_data"
1447
+ },
1448
+ {
1449
+ "name": "stderr",
1450
+ "output_type": "stream",
1451
+ "text": [
1452
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1453
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1454
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1455
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1456
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1457
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1458
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1459
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1460
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1461
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1462
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1463
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1464
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1465
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1466
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1467
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1468
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1469
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1470
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1471
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1472
+ ]
1473
+ },
1474
+ {
1475
+ "data": {
1476
+ "text/html": [
1477
+ "\n",
1478
+ " <div>\n",
1479
+ " \n",
1480
+ " <progress value='827' max='827' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1481
+ " [827/827 00:26]\n",
1482
+ " </div>\n",
1483
+ " "
1484
+ ],
1485
+ "text/plain": [
1486
+ "<IPython.core.display.HTML object>"
1487
+ ]
1488
+ },
1489
+ "metadata": {},
1490
+ "output_type": "display_data"
1491
+ },
1492
+ {
1493
+ "name": "stderr",
1494
+ "output_type": "stream",
1495
+ "text": [
1496
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1497
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1498
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1499
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1500
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1501
+ ]
1502
+ },
1503
+ {
1504
+ "name": "stdout",
1505
+ "output_type": "stream",
1506
+ "text": [
1507
+ "pancreas\n"
1508
+ ]
1509
+ },
1510
+ {
1511
+ "name": "stderr",
1512
+ "output_type": "stream",
1513
+ "text": [
1514
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1515
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1516
+ ]
1517
+ },
1518
+ {
1519
+ "data": {
1520
+ "text/html": [
1521
+ "\n",
1522
+ " <div>\n",
1523
+ " \n",
1524
+ " <progress value='18280' max='18280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1525
+ " [18280/18280 23:32, Epoch 10/10]\n",
1526
+ " </div>\n",
1527
+ " <table border=\"1\" class=\"dataframe\">\n",
1528
+ " <thead>\n",
1529
+ " <tr style=\"text-align: left;\">\n",
1530
+ " <th>Epoch</th>\n",
1531
+ " <th>Training Loss</th>\n",
1532
+ " <th>Validation Loss</th>\n",
1533
+ " <th>Accuracy</th>\n",
1534
+ " <th>Macro F1</th>\n",
1535
+ " <th>Weighted F1</th>\n",
1536
+ " </tr>\n",
1537
+ " </thead>\n",
1538
+ " <tbody>\n",
1539
+ " <tr>\n",
1540
+ " <td>1</td>\n",
1541
+ " <td>0.340100</td>\n",
1542
+ " <td>0.343200</td>\n",
1543
+ " <td>0.896244</td>\n",
1544
+ " <td>0.655661</td>\n",
1545
+ " <td>0.879469</td>\n",
1546
+ " </tr>\n",
1547
+ " <tr>\n",
1548
+ " <td>2</td>\n",
1549
+ " <td>0.178300</td>\n",
1550
+ " <td>0.224033</td>\n",
1551
+ " <td>0.930890</td>\n",
1552
+ " <td>0.859772</td>\n",
1553
+ " <td>0.925342</td>\n",
1554
+ " </tr>\n",
1555
+ " <tr>\n",
1556
+ " <td>3</td>\n",
1557
+ " <td>0.154200</td>\n",
1558
+ " <td>0.208034</td>\n",
1559
+ " <td>0.941284</td>\n",
1560
+ " <td>0.887012</td>\n",
1561
+ " <td>0.939485</td>\n",
1562
+ " </tr>\n",
1563
+ " <tr>\n",
1564
+ " <td>4</td>\n",
1565
+ " <td>0.121200</td>\n",
1566
+ " <td>0.216660</td>\n",
1567
+ " <td>0.940372</td>\n",
1568
+ " <td>0.880716</td>\n",
1569
+ " <td>0.939431</td>\n",
1570
+ " </tr>\n",
1571
+ " <tr>\n",
1572
+ " <td>5</td>\n",
1573
+ " <td>0.099900</td>\n",
1574
+ " <td>0.254255</td>\n",
1575
+ " <td>0.940554</td>\n",
1576
+ " <td>0.889088</td>\n",
1577
+ " <td>0.938300</td>\n",
1578
+ " </tr>\n",
1579
+ " <tr>\n",
1580
+ " <td>6</td>\n",
1581
+ " <td>0.065800</td>\n",
1582
+ " <td>0.267429</td>\n",
1583
+ " <td>0.942743</td>\n",
1584
+ " <td>0.897682</td>\n",
1585
+ " <td>0.942815</td>\n",
1586
+ " </tr>\n",
1587
+ " <tr>\n",
1588
+ " <td>7</td>\n",
1589
+ " <td>0.061200</td>\n",
1590
+ " <td>0.282509</td>\n",
1591
+ " <td>0.945478</td>\n",
1592
+ " <td>0.898797</td>\n",
1593
+ " <td>0.943881</td>\n",
1594
+ " </tr>\n",
1595
+ " <tr>\n",
1596
+ " <td>8</td>\n",
1597
+ " <td>0.036800</td>\n",
1598
+ " <td>0.301781</td>\n",
1599
+ " <td>0.943837</td>\n",
1600
+ " <td>0.903816</td>\n",
1601
+ " <td>0.944163</td>\n",
1602
+ " </tr>\n",
1603
+ " <tr>\n",
1604
+ " <td>9</td>\n",
1605
+ " <td>0.035400</td>\n",
1606
+ " <td>0.317026</td>\n",
1607
+ " <td>0.942560</td>\n",
1608
+ " <td>0.902241</td>\n",
1609
+ " <td>0.942071</td>\n",
1610
+ " </tr>\n",
1611
+ " <tr>\n",
1612
+ " <td>10</td>\n",
1613
+ " <td>0.014200</td>\n",
1614
+ " <td>0.313259</td>\n",
1615
+ " <td>0.946754</td>\n",
1616
+ " <td>0.904955</td>\n",
1617
+ " <td>0.946129</td>\n",
1618
+ " </tr>\n",
1619
+ " </tbody>\n",
1620
+ "</table><p>"
1621
+ ],
1622
+ "text/plain": [
1623
+ "<IPython.core.display.HTML object>"
1624
+ ]
1625
+ },
1626
+ "metadata": {},
1627
+ "output_type": "display_data"
1628
+ },
1629
+ {
1630
+ "name": "stderr",
1631
+ "output_type": "stream",
1632
+ "text": [
1633
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1634
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1635
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1636
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1637
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1638
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1639
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1640
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1641
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1642
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1643
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1644
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1645
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1646
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1647
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1648
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1649
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1650
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1651
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1652
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1653
+ ]
1654
+ },
1655
+ {
1656
+ "data": {
1657
+ "text/html": [
1658
+ "\n",
1659
+ " <div>\n",
1660
+ " \n",
1661
+ " <progress value='457' max='457' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1662
+ " [457/457 00:11]\n",
1663
+ " </div>\n",
1664
+ " "
1665
+ ],
1666
+ "text/plain": [
1667
+ "<IPython.core.display.HTML object>"
1668
+ ]
1669
+ },
1670
+ "metadata": {},
1671
+ "output_type": "display_data"
1672
+ },
1673
+ {
1674
+ "name": "stderr",
1675
+ "output_type": "stream",
1676
+ "text": [
1677
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1678
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1679
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1680
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1681
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1682
+ ]
1683
+ },
1684
+ {
1685
+ "name": "stdout",
1686
+ "output_type": "stream",
1687
+ "text": [
1688
+ "liver\n"
1689
+ ]
1690
+ },
1691
+ {
1692
+ "name": "stderr",
1693
+ "output_type": "stream",
1694
+ "text": [
1695
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1696
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1697
+ ]
1698
+ },
1699
+ {
1700
+ "data": {
1701
+ "text/html": [
1702
+ "\n",
1703
+ " <div>\n",
1704
+ " \n",
1705
+ " <progress value='18690' max='18690' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1706
+ " [18690/18690 26:56, Epoch 10/10]\n",
1707
+ " </div>\n",
1708
+ " <table border=\"1\" class=\"dataframe\">\n",
1709
+ " <thead>\n",
1710
+ " <tr style=\"text-align: left;\">\n",
1711
+ " <th>Epoch</th>\n",
1712
+ " <th>Training Loss</th>\n",
1713
+ " <th>Validation Loss</th>\n",
1714
+ " <th>Accuracy</th>\n",
1715
+ " <th>Macro F1</th>\n",
1716
+ " <th>Weighted F1</th>\n",
1717
+ " </tr>\n",
1718
+ " </thead>\n",
1719
+ " <tbody>\n",
1720
+ " <tr>\n",
1721
+ " <td>1</td>\n",
1722
+ " <td>0.388500</td>\n",
1723
+ " <td>0.385503</td>\n",
1724
+ " <td>0.878188</td>\n",
1725
+ " <td>0.673887</td>\n",
1726
+ " <td>0.871348</td>\n",
1727
+ " </tr>\n",
1728
+ " <tr>\n",
1729
+ " <td>2</td>\n",
1730
+ " <td>0.315900</td>\n",
1731
+ " <td>0.302775</td>\n",
1732
+ " <td>0.907437</td>\n",
1733
+ " <td>0.754182</td>\n",
1734
+ " <td>0.903474</td>\n",
1735
+ " </tr>\n",
1736
+ " <tr>\n",
1737
+ " <td>3</td>\n",
1738
+ " <td>0.242600</td>\n",
1739
+ " <td>0.321844</td>\n",
1740
+ " <td>0.907972</td>\n",
1741
+ " <td>0.779504</td>\n",
1742
+ " <td>0.905881</td>\n",
1743
+ " </tr>\n",
1744
+ " <tr>\n",
1745
+ " <td>4</td>\n",
1746
+ " <td>0.238600</td>\n",
1747
+ " <td>0.323119</td>\n",
1748
+ " <td>0.911539</td>\n",
1749
+ " <td>0.790922</td>\n",
1750
+ " <td>0.910299</td>\n",
1751
+ " </tr>\n",
1752
+ " <tr>\n",
1753
+ " <td>5</td>\n",
1754
+ " <td>0.160100</td>\n",
1755
+ " <td>0.328203</td>\n",
1756
+ " <td>0.915641</td>\n",
1757
+ " <td>0.793490</td>\n",
1758
+ " <td>0.913836</td>\n",
1759
+ " </tr>\n",
1760
+ " <tr>\n",
1761
+ " <td>6</td>\n",
1762
+ " <td>0.163100</td>\n",
1763
+ " <td>0.348942</td>\n",
1764
+ " <td>0.917425</td>\n",
1765
+ " <td>0.813604</td>\n",
1766
+ " <td>0.916911</td>\n",
1767
+ " </tr>\n",
1768
+ " <tr>\n",
1769
+ " <td>7</td>\n",
1770
+ " <td>0.124100</td>\n",
1771
+ " <td>0.373799</td>\n",
1772
+ " <td>0.916890</td>\n",
1773
+ " <td>0.820355</td>\n",
1774
+ " <td>0.916688</td>\n",
1775
+ " </tr>\n",
1776
+ " <tr>\n",
1777
+ " <td>8</td>\n",
1778
+ " <td>0.118700</td>\n",
1779
+ " <td>0.399474</td>\n",
1780
+ " <td>0.916890</td>\n",
1781
+ " <td>0.818839</td>\n",
1782
+ " <td>0.916640</td>\n",
1783
+ " </tr>\n",
1784
+ " <tr>\n",
1785
+ " <td>9</td>\n",
1786
+ " <td>0.066800</td>\n",
1787
+ " <td>0.414363</td>\n",
1788
+ " <td>0.917603</td>\n",
1789
+ " <td>0.830703</td>\n",
1790
+ " <td>0.917226</td>\n",
1791
+ " </tr>\n",
1792
+ " <tr>\n",
1793
+ " <td>10</td>\n",
1794
+ " <td>0.075800</td>\n",
1795
+ " <td>0.413828</td>\n",
1796
+ " <td>0.919030</td>\n",
1797
+ " <td>0.828149</td>\n",
1798
+ " <td>0.918506</td>\n",
1799
+ " </tr>\n",
1800
+ " </tbody>\n",
1801
+ "</table><p>"
1802
+ ],
1803
+ "text/plain": [
1804
+ "<IPython.core.display.HTML object>"
1805
+ ]
1806
+ },
1807
+ "metadata": {},
1808
+ "output_type": "display_data"
1809
+ },
1810
+ {
1811
+ "name": "stderr",
1812
+ "output_type": "stream",
1813
+ "text": [
1814
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1815
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1816
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1817
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1818
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1819
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1820
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1821
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1822
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1823
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1824
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1825
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1826
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1827
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1828
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1829
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1830
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1831
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1832
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1833
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1834
+ ]
1835
+ },
1836
+ {
1837
+ "data": {
1838
+ "text/html": [
1839
+ "\n",
1840
+ " <div>\n",
1841
+ " \n",
1842
+ " <progress value='936' max='468' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1843
+ " [468/468 00:39]\n",
1844
+ " </div>\n",
1845
+ " "
1846
+ ],
1847
+ "text/plain": [
1848
+ "<IPython.core.display.HTML object>"
1849
+ ]
1850
+ },
1851
+ "metadata": {},
1852
+ "output_type": "display_data"
1853
+ }
1854
+ ],
1855
+ "source": [
1856
+ "for organ in organ_list:\n",
1857
+ " print(organ)\n",
1858
+ " organ_trainset = trainset_dict[organ]\n",
1859
+ " organ_evalset = evalset_dict[organ]\n",
1860
+ " organ_label_dict = traintargetdict_dict[organ]\n",
1861
+ " \n",
1862
+ " # set logging steps\n",
1863
+ " logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)\n",
1864
+ " \n",
1865
+ " # reload pretrained model\n",
1866
+ " model = BertForSequenceClassification.from_pretrained(\"/path/to/pretrained_model/\", \n",
1867
+ " num_labels=len(organ_label_dict.keys()),\n",
1868
+ " output_attentions = False,\n",
1869
+ " output_hidden_states = False).to(\"cuda\")\n",
1870
+ " \n",
1871
+ " # define output directory path\n",
1872
+ " current_date = datetime.datetime.now()\n",
1873
+ " datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
1874
+ " output_dir = f\"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/\"\n",
1875
+ " \n",
1876
+ " # ensure not overwriting previously saved model\n",
1877
+ " saved_model_test = os.path.join(output_dir, f\"pytorch_model.bin\")\n",
1878
+ " if os.path.isfile(saved_model_test) == True:\n",
1879
+ " raise Exception(\"Model already saved to this directory.\")\n",
1880
+ "\n",
1881
+ " # make output directory\n",
1882
+ " subprocess.call(f'mkdir {output_dir}', shell=True)\n",
1883
+ " \n",
1884
+ " # set training arguments\n",
1885
+ " training_args = {\n",
1886
+ " \"learning_rate\": max_lr,\n",
1887
+ " \"do_train\": True,\n",
1888
+ " \"do_eval\": True,\n",
1889
+ " \"evaluation_strategy\": \"epoch\",\n",
1890
+ " \"save_strategy\": \"epoch\",\n",
1891
+ " \"logging_steps\": logging_steps,\n",
1892
+ " \"group_by_length\": True,\n",
1893
+ " \"length_column_name\": \"length\",\n",
1894
+ " \"disable_tqdm\": False,\n",
1895
+ " \"lr_scheduler_type\": lr_schedule_fn,\n",
1896
+ " \"warmup_steps\": warmup_steps,\n",
1897
+ " \"weight_decay\": 0.001,\n",
1898
+ " \"per_device_train_batch_size\": geneformer_batch_size,\n",
1899
+ " \"per_device_eval_batch_size\": geneformer_batch_size,\n",
1900
+ " \"num_train_epochs\": epochs,\n",
1901
+ " \"load_best_model_at_end\": True,\n",
1902
+ " \"output_dir\": output_dir,\n",
1903
+ " }\n",
1904
+ " \n",
1905
+ " training_args_init = TrainingArguments(**training_args)\n",
1906
+ "\n",
1907
+ " # create the trainer\n",
1908
+ " trainer = Trainer(\n",
1909
+ " model=model,\n",
1910
+ " args=training_args_init,\n",
1911
+ " data_collator=DataCollatorForCellClassification(),\n",
1912
+ " train_dataset=organ_trainset,\n",
1913
+ " eval_dataset=organ_evalset,\n",
1914
+ " compute_metrics=compute_metrics\n",
1915
+ " )\n",
1916
+ " # train the cell type classifier\n",
1917
+ " trainer.train()\n",
1918
+ " predictions = trainer.predict(organ_evalset)\n",
1919
+ " with open(f\"{output_dir}predictions.pickle\", \"wb\") as fp:\n",
1920
+ " pickle.dump(predictions, fp)\n",
1921
+ " trainer.save_metrics(\"eval\",predictions.metrics)\n",
1922
+ " trainer.save_model(output_dir)"
1923
+ ]
1924
+ }
1925
+ ],
1926
+ "metadata": {
1927
+ "kernelspec": {
1928
+ "display_name": "Python 3 (ipykernel)",
1929
+ "language": "python",
1930
+ "name": "python3"
1931
+ },
1932
+ "language_info": {
1933
+ "codemirror_mode": {
1934
+ "name": "ipython",
1935
+ "version": 3
1936
+ },
1937
+ "file_extension": ".py",
1938
+ "mimetype": "text/x-python",
1939
+ "name": "python",
1940
+ "nbconvert_exporter": "python",
1941
+ "pygments_lexer": "ipython3",
1942
+ "version": "3.10.11"
1943
+ },
1944
+ "vscode": {
1945
+ "interpreter": {
1946
+ "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
1947
+ }
1948
+ }
1949
+ },
1950
+ "nbformat": 4,
1951
+ "nbformat_minor": 5
1952
+ }
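The Accuracy, Macro F1, and Weighted F1 columns in the tables above are produced by the `compute_metrics` function passed to the `Trainer` in this loop; that function is defined in an earlier cell that is not part of this excerpt. A minimal sketch of such a metric function, using scikit-learn and the standard `EvalPrediction` interface (illustrative only, not necessarily the notebook's exact code):

```python
# Illustrative sketch of a metric function for the cell type classifier.
# Not the notebook's exact code; it simply reproduces the columns shown above.
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)  # predicted cell type per cell
    return {
        "accuracy": accuracy_score(labels, preds),
        "macro_f1": f1_score(labels, preds, average="macro"),
        "weighted_f1": f1_score(labels, preds, average="weighted"),
    }
```

Macro F1 weights all cell types equally, which is why it lags accuracy for organs with rare cell types, while weighted F1 tracks accuracy closely.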
examples/extract_and_plot_cell_embeddings.ipynb ADDED
The diff for this file is too large to render. See raw diff
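Since this notebook is not rendered here, the following is only a rough sketch of the underlying idea: a cell embedding is obtained by mean-pooling a model's final hidden states over the non-padded positions of that cell's rank value encoding (the same `mean_pool` style used by the in silico perturbation example below), after which the embeddings can be projected to 2D (e.g. with UMAP) for plotting. Function and path names below are illustrative assumptions, not the notebook's code.

```python
# Rough sketch only (not the notebook's code): one mean-pooled embedding per cell.
import torch
from transformers import BertForSequenceClassification

def extract_cell_embeddings(model_dir, batch, device="cuda"):
    """batch: dict with padded 'input_ids' and 'attention_mask' tensors."""
    model = BertForSequenceClassification.from_pretrained(
        model_dir, output_hidden_states=True
    ).to(device).eval()
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device))
    hidden = out.hidden_states[-1]                           # (batch, seq_len, emb_dim)
    mask = batch["attention_mask"].unsqueeze(-1).to(device)  # zero out padding positions
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)      # mean-pooled cell embeddings
```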
 
examples/gene_classification.ipynb ADDED
The diff for this file is too large to render. See raw diff
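This notebook is likewise not rendered. Gene-level tasks are predicted per gene token rather than per cell, so a plausible setup (an assumption, not taken from the unrendered notebook) swaps the sequence classification head used above for a token classification head:

```python
# Hedged sketch: per-gene (token-level) classification head on the same pretrained checkpoint.
from transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained(
    "/path/to/pretrained_model/",  # same pretrained Geneformer checkpoint as above
    num_labels=2,                  # e.g. two gene classes for a binary gene-level task
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")
```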
 
examples/hyperparam_optimiz_for_disease_classifier.py ADDED
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # hyperparameter optimization with raytune for disease classification
5
+
6
+ # imports
7
+ import os
8
+ import subprocess
9
+ GPU_NUMBER = [0,1,2,3]
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
11
+ os.environ["NCCL_DEBUG"] = "INFO"
12
+ os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
13
+ os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
14
+
15
+ # initiate runtime environment for raytune
16
+ import pyarrow # must occur prior to ray import
17
+ import ray
18
+ from ray import tune
19
+ from ray.tune import ExperimentAnalysis
20
+ from ray.tune.suggest.hyperopt import HyperOptSearch
21
+ ray.shutdown() #engage new ray session
22
+ runtime_env = {"conda": "base",
23
+ "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
24
+ ray.init(runtime_env=runtime_env)
25
+
26
+ def initialize_ray_with_check(ip_address):
27
+ """
28
+ Initialize Ray with a specified IP address and check its status and accessibility.
29
+
30
+ Args:
31
+ - ip_address (str): The IP address (with port) to initialize Ray.
32
+
33
+ Returns:
34
+ - bool: True if initialization was successful and dashboard is accessible, False otherwise.
35
+ """
36
+ try:
37
+ ray.init(address=ip_address)
38
+ print(ray.nodes())
39
+
40
+ services = ray.get_webui_url()
41
+ if not services:
42
+ raise RuntimeError("Ray dashboard is not accessible.")
43
+ else:
44
+ print(f"Ray dashboard is accessible at: {services}")
45
+ return True
46
+ except Exception as e:
47
+ print(f"Error initializing Ray: {e}")
48
+ return False
49
+
50
+ # Usage:
51
+ ip = 'your_ip:xxxx' # Replace with your actual IP address and port
52
+ if initialize_ray_with_check(ip):
53
+ print("Ray initialized successfully.")
54
+ else:
55
+ print("Error during Ray initialization.")
56
+
57
+ import datetime
58
+ import numpy as np
59
+ import pandas as pd
60
+ import random
61
+ import seaborn as sns; sns.set()
62
+ from collections import Counter
63
+ from datasets import load_from_disk
64
+ from scipy.stats import ranksums
65
+ from sklearn.metrics import accuracy_score
66
+ from transformers import BertForSequenceClassification
67
+ from transformers import Trainer
68
+ from transformers.training_args import TrainingArguments
69
+
70
+ from geneformer import DataCollatorForCellClassification
71
+
72
+ # number of CPU cores
73
+ num_proc=30
74
+
75
+ # load train dataset with columns:
76
+ # cell_type (annotation of each cell's type)
77
+ # disease (healthy or disease state)
78
+ # individual (unique ID for each patient)
79
+ # length (length of that cell's rank value encoding)
80
+ train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
81
+
82
+ # filter dataset for given cell_type
83
+ def if_cell_type(example):
84
+ return example["cell_type"].startswith("Cardiomyocyte")
85
+
86
+ trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
87
+
88
+ # create dictionary of disease states : label ids
89
+ target_names = ["healthy", "disease1", "disease2"]
90
+ target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
91
+
92
+ trainset_v3 = trainset_v2.rename_column("disease","label")
93
+
94
+ # change labels to numerical ids
95
+ def classes_to_ids(example):
96
+ example["label"] = target_name_id_dict[example["label"]]
97
+ return example
98
+
99
+ trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
100
+
101
+ # separate into train, validation, test sets
102
+ indiv_set = set(trainset_v4["individual"])
103
+ random.seed(42)
104
+ train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
105
+ eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
106
+ valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
107
+ test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
108
+
109
+ def if_train(example):
110
+ return example["individual"] in train_indiv
111
+
112
+ classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
113
+
114
+ def if_valid(example):
115
+ return example["individual"] in valid_indiv
116
+
117
+ classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
118
+
119
+ # define output directory path
120
+ current_date = datetime.datetime.now()
121
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
122
+ output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
123
+
124
+ # ensure not overwriting previously saved model
125
+ saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
126
+ if os.path.isfile(saved_model_test) == True:
127
+ raise Exception("Model already saved to this directory.")
128
+
129
+ # make output directory
130
+ subprocess.call(f'mkdir {output_dir}', shell=True)
131
+
132
+ # set training parameters
133
+ # how many pretrained layers to freeze
134
+ freeze_layers = 2
135
+ # batch size for training and eval
136
+ geneformer_batch_size = 12
137
+ # number of epochs
138
+ epochs = 1
139
+ # logging steps
140
+ logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
141
+
142
+ # define function to initiate model
143
+ def model_init():
144
+ model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
145
+ num_labels=len(target_names),
146
+ output_attentions = False,
147
+ output_hidden_states = False)
148
+ if freeze_layers is not None:
149
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
150
+ for module in modules_to_freeze:
151
+ for param in module.parameters():
152
+ param.requires_grad = False
153
+
154
+ model = model.to("cuda:0")
155
+ return model
156
+
157
+ # define metrics
158
+ # note: macro f1 score recommended for imbalanced multiclass classifiers
159
+ def compute_metrics(pred):
160
+ labels = pred.label_ids
161
+ preds = pred.predictions.argmax(-1)
162
+ # calculate accuracy using sklearn's function
163
+ acc = accuracy_score(labels, preds)
164
+ return {
165
+ 'accuracy': acc,
166
+ }
167
+
168
+ # set training arguments
169
+ training_args = {
170
+ "do_train": True,
171
+ "do_eval": True,
172
+ "evaluation_strategy": "steps",
173
+ "eval_steps": logging_steps,
174
+ "logging_steps": logging_steps,
175
+ "group_by_length": True,
176
+ "length_column_name": "length",
177
+ "disable_tqdm": True,
178
+ "skip_memory_metrics": True, # memory tracker causes errors in raytune
179
+ "per_device_train_batch_size": geneformer_batch_size,
180
+ "per_device_eval_batch_size": geneformer_batch_size,
181
+ "num_train_epochs": epochs,
182
+ "load_best_model_at_end": True,
183
+ "output_dir": output_dir,
184
+ }
185
+
186
+ training_args_init = TrainingArguments(**training_args)
187
+
188
+ # create the trainer
189
+ trainer = Trainer(
190
+ model_init=model_init,
191
+ args=training_args_init,
192
+ data_collator=DataCollatorForCellClassification(),
193
+ train_dataset=classifier_trainset,
194
+ eval_dataset=classifier_validset,
195
+ compute_metrics=compute_metrics,
196
+ )
197
+
198
+ # specify raytune hyperparameter search space
199
+ ray_config = {
200
+ "num_train_epochs": tune.choice([epochs]),
201
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
202
+ "weight_decay": tune.uniform(0.0, 0.3),
203
+ "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
204
+ "warmup_steps": tune.uniform(100, 2000),
205
+ "seed": tune.uniform(0,100),
206
+ "per_device_train_batch_size": tune.choice([geneformer_batch_size])
207
+ }
208
+
209
+ hyperopt_search = HyperOptSearch(
210
+ metric="eval_accuracy", mode="max")
211
+
212
+ # optimize hyperparameters
213
+ trainer.hyperparameter_search(
214
+ direction="maximize",
215
+ backend="ray",
216
+ resources_per_trial={"cpu":8,"gpu":1},
217
+ hp_space=lambda _: ray_config,
218
+ search_alg=hyperopt_search,
219
+ n_trials=100, # number of trials
220
+ progress_reporter=tune.CLIReporter(max_report_frequency=600,
221
+ sort_by_metric=True,
222
+ max_progress_rows=100,
223
+ mode="max",
224
+ metric="eval_accuracy",
225
+ metric_columns=["loss", "eval_loss", "eval_accuracy"])
226
+ )
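One small, hedged addition to the script above (not part of the committed file): `Trainer.hyperparameter_search` returns a `BestRun`, so capturing the return value makes the winning trial easy to inspect and reapply before a final training run.

```python
# Sketch: capture the best trial and reapply its hyperparameters for a final training run.
best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,
)
print(f"best trial {best_run.run_id}: eval_accuracy={best_run.objective}")
for name, value in best_run.hyperparameters.items():
    setattr(trainer.args, name, value)  # overwrite the TrainingArguments with the best values
trainer.train()
```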
examples/in_silico_perturbation.ipynb ADDED
@@ -0,0 +1,110 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "e10ac0c9-40ce-41fb-b6fa-3d62b76f2e57",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from geneformer import InSilicoPerturber\n",
11
+ "from geneformer import InSilicoPerturberStats"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "67b44366-f255-4415-a865-6a27a8ffcce7",
18
+ "metadata": {
19
+ "tags": []
20
+ },
21
+ "outputs": [],
22
+ "source": [
23
+ "# in silico perturbation in deletion mode to determine genes whose \n",
24
+ "# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
25
+ "# the embedding towards non-failing (nf) state\n",
26
+ "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
27
+ " perturb_rank_shift=None,\n",
28
+ " genes_to_perturb=\"all\",\n",
29
+ " combos=0,\n",
30
+ " anchor_gene=None,\n",
31
+ " model_type=\"CellClassifier\",\n",
32
+ " num_classes=3,\n",
33
+ " emb_mode=\"cell\",\n",
34
+ " cell_emb_style=\"mean_pool\",\n",
35
+ " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
+ " cell_states_to_model={'state_key': 'disease', \n",
37
+ " 'start_state': 'dcm', \n",
38
+ " 'goal_state': 'nf', \n",
39
+ " 'alt_states': ['hcm']},\n",
40
+ " max_ncells=2000,\n",
41
+ " emb_layer=0,\n",
42
+ " forward_batch_size=400,\n",
43
+ " nproc=16)"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "id": "0525a663-871a-4ce0-a135-cc203817ffa9",
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# outputs intermediate files from in silico perturbation\n",
54
+ "isp.perturb_data(\"path/to/model\",\n",
55
+ " \"path/to/input_data\",\n",
56
+ " \"path/to/output_directory\",\n",
57
+ " \"output_prefix\")"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
68
+ " genes_perturbed=\"all\",\n",
69
+ " combos=0,\n",
70
+ " anchor_gene=None,\n",
71
+ " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "id": "ffecfae6-e737-43e3-99e9-fa37ff46610b",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "# extracts data from intermediate files and processes stats to output in final .csv\n",
82
+ "ispstats.get_stats(\"path/to/input_data\",\n",
83
+ " None,\n",
84
+ " \"path/to/output_directory\",\n",
85
+ " \"output_prefix\")"
86
+ ]
87
+ }
88
+ ],
89
+ "metadata": {
90
+ "kernelspec": {
91
+ "display_name": "Python 3 (ipykernel)",
92
+ "language": "python",
93
+ "name": "python3"
94
+ },
95
+ "language_info": {
96
+ "codemirror_mode": {
97
+ "name": "ipython",
98
+ "version": 3
99
+ },
100
+ "file_extension": ".py",
101
+ "mimetype": "text/x-python",
102
+ "name": "python",
103
+ "nbconvert_exporter": "python",
104
+ "pygments_lexer": "ipython3",
105
+ "version": "3.10.11"
106
+ }
107
+ },
108
+ "nbformat": 4,
109
+ "nbformat_minor": 5
110
+ }
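The final cell of the notebook above writes its processed statistics to a .csv in the chosen output directory. As a quick follow-up, that table can be loaded with pandas for inspection; the path below is a placeholder built from the `output_directory`/`output_prefix` arguments, and the exact file name and column names depend on the geneformer version, so treat this as a sketch rather than the package's documented interface.

```python
# Sketch only: load the final stats .csv produced by ispstats.get_stats above.
import pandas as pd

stats = pd.read_csv("path/to/output_directory/output_prefix.csv")  # placeholder path
print(stats.shape)
print(stats.head())
# e.g. sort by the goal-state shift column (name is version-dependent) to rank genes
# whose in silico deletion most shifts cardiomyocytes from the dcm toward the nf state
```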
examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb ADDED
@@ -0,0 +1,365 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "charged-worcester",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Obtain non-zero median expression value of each gene across Genecorpus-30M"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "28e87f2a-a33e-4fe3-81af-ad4cd62fcc1b",
14
+ "metadata": {},
15
+ "source": [
16
+ "#### Upon request, we are providing the code that we used for obtaining the non-zero median expression value of each gene across the broad range of cell types represented in Genecorpus-30M that we use as a normalization factor to prioritize genes that uniquely distinguish cell state.\n",
17
+ "\n",
18
+ "#### Please read the important information below before using this code.\n",
19
+ "\n",
20
+ "#### If using Geneformer, to ensure consistency of the normalization factor used for each gene for all future datasets, <ins>**users should use the Geneformer transcriptome tokenizer to tokenize their datasets and should not re-calculate this normalization factor for their individual dataset** </ins>. This code for re-calculating the normalization factor should only be used by users who are pretraining a new model from scratch with a new pretraining corpus other than Genecorpus-30M.\n",
21
+ "\n",
22
+ "#### It is critical that this calculation is performed on a large-scale pretraining corpus that has tens of millions of cells from a broad range of human tissues. <ins>**The richness of variable cell states in the pretraining corpus is what allows this normalization factor to accomplish the goal of prioritizing genes that uniquely distinguish cell states.** </ins> This normalization factor for each gene is calculated once from the large-scale pretraining corpus and is used for all future datasets presented to the model. \n",
23
+ "\n",
24
+ "#### Of note, as discussed in the Methods, we only included droplet-based sequencing platforms in the pretraining corpus to assure expression value unit comparability for the calculation of this normalization factor. Users wishing to pretrain a new model from scratch with a new pretraining corpus should choose either droplet-based or plate-based platforms for calculating this normalization factor, or they should exercise caution that including both platforms may cause unintended effects on the results. Once the normalization factor is calculated however, data from any platform can be used with the model because the expression value units will be consistent within each individual cell.\n",
25
+ "\n",
26
+ "#### Please see the Methods in the manuscript for a description of the procedure enacted by this code, an excerpt of which is below for convenience:\n",
27
+ "\n",
28
+ "#### \"To accomplish this, we first calculated the non-zero median value of expression of each detected gene across all cells passing quality filtering from the entire Genecorpus-30M. We aggregated the transcript count distribution for each gene in a memory-efficient manner by scanning through chunks of .loom data using loompy, normalizing the gene transcript counts in each cell by the total transcript count of that cell to account for varying sequencing depth and updating the normalized count distribution of the gene within the t-digest data structure developed for accurate online accumulation of rank-based statistics. We then normalized the genes in each single-cell transcriptome by the non-zero median value of expression of that gene across Genecorpus-30M and ordered the genes by the rank of their normalized expression in that specific cell. Of note, we opted to use the non-zero median value of expression rather than include zeros in the distribution so as not to weight the value by tissue representation within Genecorpus-30M, assuming that a representative range of transcript values would be observed within the cells in which each gene was detected. This normalization factor for each gene is calculated once from the pretraining corpus and is used for all future datasets presented to the model. The provided tokenizer code includes this normalization procedure and should be used for tokenizing new datasets presented to Geneformer to ensure consistency of the normalization factor used for each gene.\""
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 1,
34
+ "id": "textile-destruction",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "import numpy as np\n",
40
+ "import loompy as lp\n",
41
+ "import pandas as pd\n",
42
+ "import crick\n",
43
+ "import pickle\n",
44
+ "import math\n",
45
+ "from tqdm.notebook import tqdm"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "id": "4af8cfef-05f2-47e0-b8d2-71ca025059c7",
51
+ "metadata": {
52
+ "tags": []
53
+ },
54
+ "source": [
55
+ "### The following code is an example of how the nonzero median expression values are obtained for a single input file. This calculation should be run as a script to be parallelized for all dataset files."
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 30,
61
+ "id": "physical-intro",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "input_file = \"study1.loom\"\n",
66
+ "current_database = \"database1\"\n",
67
+ "\n",
68
+ "rootdir = f\"/path/to/{current_database}/data/\"\n",
69
+ "output_file = input_file.replace(\".loom\", \".gene_median_digest_dict.pickle\")\n",
70
+ "outdir = rootdir.replace(\"/data/\", \"/tdigest/\")\n",
71
+ "\n",
72
+ "with lp.connect(f\"{rootdir}{input_file}\") as data:\n",
73
+ " # define coordinates of protein-coding or miRNA genes\n",
74
+ " coding_miRNA_loc = np.where((data.ra.gene_type == \"protein_coding\") | (data.ra.gene_type == \"miRNA\"))[0]\n",
75
+ " coding_miRNA_genes = data.ra[\"ensembl_id\"][coding_miRNA_loc]\n",
76
+ " \n",
77
+ " # initiate tdigests\n",
78
+ " median_digests = [crick.tdigest.TDigest() for _ in range(len(coding_miRNA_loc))]\n",
79
+ " \n",
80
+ " # initiate progress meters\n",
81
+ " progress = tqdm(total=len(coding_miRNA_loc))\n",
82
+ " last_view_row = 0\n",
83
+ " progress.update(0)\n",
84
+ " \n",
85
+ " for (ix, selection, view) in data.scan(items=coding_miRNA_loc, axis=0):\n",
86
+ " # define coordinates of cells passing filter\n",
87
+ " filter_passed_loc = np.where(view.ca.filter_pass == 1)[0]\n",
88
+ " subview = view.view[:, filter_passed_loc]\n",
89
+ " # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision\n",
90
+ " subview_norm_array = subview[:,:]/subview.ca.n_counts*10_000\n",
91
+ " # if integer, convert to float to prevent error with filling with nan\n",
92
+ " if np.issubdtype(subview_norm_array.dtype, np.integer):\n",
93
+ " subview_norm_array = subview_norm_array.astype(np.float32)\n",
94
+ " # mask zeroes from distribution tdigest by filling with nan\n",
95
+ " nonzero_data = np.ma.masked_equal(subview_norm_array, 0.0).filled(np.nan)\n",
96
+ " # update tdigests\n",
97
+ " [median_digests[i+last_view_row].update(nonzero_data[i,:]) for i in range(nonzero_data.shape[0])]\n",
98
+ " # update progress meters\n",
99
+ " progress.update(view.shape[0])\n",
100
+ " last_view_row = last_view_row + view.shape[0]\n",
101
+ " \n",
102
+ "median_digest_dict = dict(zip(coding_miRNA_genes, median_digests))\n",
103
+ "with open(f\"{outdir}{output_file}\", \"wb\") as fp:\n",
104
+ " pickle.dump(median_digest_dict, fp)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "markdown",
109
+ "id": "190a3754-aafa-4ccf-ba97-951c94ea3030",
110
+ "metadata": {
111
+ "tags": []
112
+ },
113
+ "source": [
114
+ "### After the above code is run as a script in parallel for all datasets to obtain the nonzero median tdigests for their contained genes, the following code can be run to merge the tdigests across all datasets."
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 2,
120
+ "id": "distributed-riding",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "# merge new tdigests into total tdigest dict\n",
125
+ "def merge_digest(dict_key_ensembl_id, dict_value_tdigest, new_tdigest_dict):\n",
126
+ " new_gene_tdigest = new_tdigest_dict.get(dict_key_ensembl_id)\n",
127
+ " if new_gene_tdigest is not None:\n",
128
+ " dict_value_tdigest.merge(new_gene_tdigest)\n",
129
+ " return dict_value_tdigest\n",
130
+ " elif new_gene_tdigest is None:\n",
131
+ " return dict_value_tdigest"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "distinct-library",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "# use tdigest1.merge(tdigest2) to merge tdigest1, tdigest2, ...tdigestn\n",
142
+ "# then, extract median by tdigest1.quantile(0.5)\n",
143
+ "\n",
144
+ "databases = [\"database1\", \"database2\", \"...databaseN\"]\n",
145
+ "\n",
146
+ "# obtain gene list\n",
147
+ "gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
148
+ "func_gene_list = [i for i in gene_info[(gene_info[\"gene_type\"] == \"protein_coding\") | (gene_info[\"gene_type\"] == \"miRNA\")][\"ensembl_id\"]]\n",
149
+ "\n",
150
+ "# initiate tdigests\n",
151
+ "median_digests = [crick.tdigest.TDigest() for _ in range(len(func_gene_list))]\n",
152
+ "total_tdigest_dict = dict(zip(func_gene_list, median_digests))\n",
153
+ "\n",
154
+ "# merge tdigests\n",
155
+ "for current_database in databases:\n",
156
+ " rootdir = f\"/path/to/{current_database}/tdigest/\"\n",
157
+ " \n",
158
+ " for subdir, dirs, files in os.walk(rootdir):\t\n",
159
+ " for file in files:\n",
160
+ " if file.endswith(\".gene_median_digest_dict.pickle\"):\n",
161
+ " with open(f\"{rootdir}{file}\", \"rb\") as fp:\n",
162
+ " tdigest_dict = pickle.load(fp)\n",
163
+ " total_tdigest_dict = {k: merge_digest(k,v,tdigest_dict) for k, v in total_tdigest_dict.items()}\n",
164
+ "\n",
165
+ "# save dict of merged tdigests\n",
166
+ "with open(f\"/path/to/total_gene_tdigest_dict.pickle\", \"wb\") as fp:\n",
167
+ " pickle.dump(total_tdigest_dict, fp)\n",
168
+ "\n",
169
+ "# extract medians and save dict\n",
170
+ "total_median_dict = {k: v.quantile(0.5) for k, v in total_tdigest_dict.items()}\n",
171
+ "with open(f\"/path/to/total_gene_median_dict.pickle\", \"wb\") as fp:\n",
172
+ " pickle.dump(total_median_dict, fp)\n",
173
+ "\n",
174
+ "# save dict of only detected genes' medians \n",
175
+ "detected_median_dict = {k: v for k, v in total_median_dict.items() if not math.isnan(v)}\n",
176
+ "with open(f\"/path/to/detected_gene_median_dict.pickle\", \"wb\") as fp:\n",
177
+ " pickle.dump(detected_median_dict, fp)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "id": "e8e17ad6-79ac-4f34-aa0c-1eaa1bace2e5",
183
+ "metadata": {
184
+ "tags": []
185
+ },
186
+ "source": [
187
+ "### The below code displays some characteristics of the genes detected in the pretraining corpus."
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 38,
193
+ "id": "decent-switzerland",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "gene_detection_counts_dict = {k: v.size() for k, v in total_tdigest_dict.items()}"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 44,
203
+ "id": "polished-innocent",
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stderr",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "/home1/ct68/miniconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
211
+ " warnings.warn(msg, FutureWarning)\n"
212
+ ]
213
+ },
214
+ {
215
+ "data": {
216
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAMRCAYAAABlG8GWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABcSAAAXEgFnn9JSAAC/KUlEQVR4nOzdd5hjZ3X48e/Z7l2vK240G0wzBgOmmmp6NT/TQmjBlCS0ACGE3gklJARCLyGYGgi9hRqwgYBpxnRMMTZgsI1x2+Lt5/fHe8d7dUfSSBpdaWb2+3kePaN7dcs7M1ea0dF5z4nMRJIkSZIkSZLGbdm0ByBJkiRJkiRpaTL4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWrFi2gOQJEmTERFXAY4GDgf2A1YDG4BLgPOA0zPzT9Ma3yAi4mTgkbVVj8rMk/tsfwTwm9qqczLziDbGJqm3iDib8toz4xqZefZ0RqM9XUQcD3yltsq/DZLUIoOPkiQtYRFxLHAScF863/j32v4c4AvAe4GvZWa2OkANLCJOAe4wzmNmZozzeJImw9cDSdJi4rRrSZL6iIjlEXFpRGR1e/uA271m0mNtjOeGEfFF4HvA3zFA4LFyOPDXwKnAryLiERHh/wuSJEmSRuKbCUmS+rsJsE9t+Ss9trtZY7tT2hpQP1E8AzgduEufTRO4mDLtuld24zWBdwPfGOsgJUmSJO0xnHYtSVJ/zWltp/TY7k61+7uAr7Yymj6qDMX/AB7V5eE/AB8DPgt8F7gwM3dW+60GrgPcFrgf5XtZXtv3ei0OW6M7DXjXtAchaUHw9UCStGAZfJQkqb/ja/fPzMw/9NiuHnz8QWZe3N6QenoDswOPG4FXAK/JzMu77ZSZW4EfVbc3R8SRwHMptSKtAbZwnZmZb5n2IKRB2Myjdb4eSJIWLKddS5LUQ5VJeNvaqq5TriNiFXDrubZrU0Q8Gnh8Y/X5wO0y8+W9Ao/dZOavM/PRlO/prDEOU5IkSdIexuCjJEm93RjYr7bcK6h4K2BtbfmUdobTXURcGWg2uLkEuG1mnjHqcTPzNErNyy+MPDhJkiRJezSDj5Ik9bZY6j0+n85mNwBPycxfzffAmXkZ8JfzPY4kSZKkPZM1HyVJ6u342v2fZOYFPbarBx+/n5mXtjekThFxILPrPH41M989rnNk5q5R9ouI/SlZoYcAB1G6av8J+A1wWmZuG9cY2xQRhwI3Ao4A9gVWAZcDlwG/BX49jkCveouIFcAtgBsABwJbKE2UvtfWzz4i1lGu3+sA+1M+WPhjZg7U1CMiDgeOpVz7BwKbgQuAnwI/zMxeXeYHHV8ARwLHAIdRPoCI6jwXAedQ6gCeN+Lx96mOfR1KBvhaYCuwCTgXOBv4aWZun8/3MR8RsS9wG+DawN7ApZTr4quZeeGYzrEGuD1wdeBgys/gt8C3MvO34zjHYlS7/q5L+dnsQ0lsuRi4kPLc/E1L574ecDRwJeAAYAfld/8r4EeZ+acxnWcl5TXgBpTXgMspz+Fv+ZovScMx+ChJUhdVvcfb1Vad0mO7vYBbzrVdix4GrG6se9OEx3CFiFgOPBL4a+DmdHbNrtsYEZ8BXpyZP5vU+AZVBbseW91uOsD2FwFfAz4MvH/UgO1CEBF3BT5H5wyZl2TmC4c4xq2BU+n8X/M1mfm0Hts3A3HXyMyzI2It8EzgiZQAXrd9Twf+KTM/NsT4TgLeWVt1amYeXz12beBFwAOY/dyCPh2FI2I/4O8p2cLX6TOE8yLivcArMvOiQcddnWN/4OnAwylBn7m2/y3wReA9mXnqANvfHXgycDfmfq+wJSK+A3wMOLlfo62IOBs4vLbqGpl5dp/tXwTUr7l3ZeZJ1WPXBP4JeCCwssvuGRFfAp6VmafP8T30Ov9VgJdRroO9e2zzDcpr2Bd6jPnFmfmiUc6/EFXX3v2Be1MCsl2fk7XtzwXeAbx+vsHgqhHaPwInAFfus2lGxA+BjwLvyMxzRzjXeuA5wOPoLL1S3+ZnwAsy88PDHl+S9kROu5Yk7bEiInvdgJ2UTIcZT+yx3WZKJtyMf+hz3ONb+Dbu21i+iBIImLiIuC0lq+sdlGyRXoFHKG/mHwz8KCJeXmXRLAhVxtrpwJsZIPBYOQD4f8B7mD0FflHJzC8CL2msfl4VlJpTRFwJ+CCdgatvUoKIA4uIawDfA15A/yDHscBHI+JDVZbayKrGTT8CHkr3wGO/ff+W0qDpBfQPPAIcSgkgnhURDxziHHcDfkkJjMwZeKxcHXgM8O9zHHtNRHyQEni+F4MlKayhfEjzb8wuU9GKiHgI8GPgIXQPPELJAL0r8K2I+KsRz/FzygcpXQOPlVsDn4+If68+sFqyIuIY4DzgP4D7MUfgsXIVyvPhVxFxnxHPuyYi3kT5ffwt/QOPUH73NwJeDHx6hPPdiPIa8Cx6BB4rRwEfiog3L/XfvSSNgy+UkiQtUlV23m0aq78xjenMEfEI4H/pHnRJyhTljV0eWw48G/ivhfAGLiIOoGQw3rDHJpspUwo3TWxQ0/FS4PO15WXAeyPiqv12qn6H7wPq210I/MWQ03OvRLmertdYv5Ey9bGbBwIfj4ihgoYzIuKRlMB5c/9LgJ5jj4jlEfF64C10fmAxYydlKurWLo/tC/x3RPzdAOO7DfApugd9EthA+Vl3O88gPgz8RY/HtgJ/pjyPp5bVGxEPp1xfe9VW76L8fLu97q0ATo6IOw5xjkcC76V70HHmXDsb65/M7KZfS81aOj9oq9tOuT66vcZDuc4/GREPHuaEVTO1rwKPp3cw/DLKtd/1EEOe7waUxnKHNx66jN6v+Y+jBFglSX1M/Z98SZI0sqMomUd135n0IKqMlnfR+cb0YuBVwHHAmszcNzPXUzJJHgB8o3GYB1Ma50zby4Cr1ZYTeDdlCur+mbkuMw/KzL0p3+/RlAyskyn1LJeEatr4w4Hf1VZfiRIo65VtBuV3eLf6oYCHZ+bvhxzCG4BrVPd/TalreqXMXJ+ZaykZVU+m1F+ruzvwyiHPBSU7cKZcwS7KlOw7AKszc39KQPIISjZU0yuBJzXW/Yoy/fr6wMrMPCAz11TneQKlHuOMAF4bEXeiv7fR+Ry7DHg5pbzBuszcp7o211ACZzenBG0+Se+AbRlACQrdu7H6a5Tp41fJzDWZeaXM3JcSBDoCuA/wauDMOcY9LjegZN0Fpebnayh1QFdVP9/V1TZvoDNAGsB/VCUh+qqy3t5O53ukXcBbKdncqzPzAMrv4UaU17iZYO+TgXuM/N0tHpcDn6Fc87cB9svMVdX1sR5YR8kIfTWdwciZ38ORg5ykKmnyGcp1XHcxJTP7lpTfx76ZuQ/ld3ITyvPrS8wOEM9lL8qsgZk
PED5G+X2uq86xN+V15x8oH0jUPSci5sp2lqQ9Wsyz1rUkSYtWRDyux0PLgNexe9rwl4EPddluNfDa2vIngc/2OeUnM/MPQw6zp4i4H6WuVd39MvPj4zrHAGM4HPg+nRlfXwQekZnnz7Hv8ygZdjN2ATfLzO/32edkylTIGY/KzJP7bH8EpcHNjHMy84ge266iBLP2ra3+y8z8YK/jd9n/AcDHM7NvsGcUEXEKnVNbr6iB15aIuBUl86gecPz3zHxql23vQsmWrAduXpqZc2YFdan5OONTlN/B5h77HVidsz49fhdw28z8Zp/znURnzccZG4D7ZuYpc425Os59gU80Vv8r8Nx+GcgRsTclg69eNuEPwJGZuaXL9rcAvlVbdQlwq8wcKPBXZfTeOTO7vY4REf8D3LO26s3AEwdtihMRtwf+1K9+6xhqPs74DXCvzPx5n30fQfnQoO7/ZeYn++yzjDLN/8a11RuBe2fmV/vsdzQl2HVol4dbq/k46deDiLguJUD9jkGbqlV1Mz9JKY0w4z8z8zED7Hsyna/1AB+nvOZfMsD+R1Cey6/r8fjxlCzHps2UD0x6li+pMiT/j84SGz1r2kqSgMz05s2bN2/evNVuwM0oGVszt4f32O4Oje3uOeFxPqFx/gRuN+Ex/Gfj/F+lZKMMuv8bG/t/YI7tT25sf9Ic2x/R2P7sPtter7Htt6d9LTbGd0qX3/d8bicOeN6ndNn3AY1trkIJ3Na3+RKwbMBzdBvfDylZs3PteyXg/Ma+n5ljn5N6nPM+Q/w+llHqL9b3f9UQ+6+hBLvq+z+ux7aPG/U8A47lvNqxtwH7tnD9nt34Ho6YY/sXdfn9XAZce8DzfbKx77vn2P5eo14PlIDl9i77v2jcP8faOafyejDCOA+mTMmeOc8WShZ5v32OoXyIUB/fhwd9PRlwXMf3+Dn85YD7P62x32/b+l178+bN21K4Oe1akqTZmtMfv9xju+Nr93cCX29lNL2t77JuoIyUcajq/z28tmoH8NjMHKbm3HMpAYUZD6yy2abhgMbyr6YyigUmM/+d8sa/7j+jdIWeqT36AeCg2uN/AB6a8+v6/eTskgXYZXwXUq6juntUDWuG8enMHKZBxQOBa9WWfwU8b9Cdq+/tHxure2Vjt31t1o9/YQ6Y2TYFr8zMXw647dsay83pu03Nn/0nBr0eMvMMdk/bV01mXkCppTpjNWVadj/PprNe4x8pf1varjf6xcz8wIDbvpPyN2/G1SLikBbGJElLgsFHSZJmqwcfz8zeU6WPr93/Xmb2Knrflm6NNSbZCOVBdE7H/Vxm/mKYA2SZPve52qrllO6503BJY/kmC6EJzgLxaKD+u90H+HBVl+2VwG1rj+2gZA816zEO46c54NTnynvpDGIvY3YNw7k0g1VzeVhj+S05ZLOnzPwyJetwxjHVFOmmSxrLg3ZhH1T9+IfM1VhoSnYx3O+oWVf2Or2ez1UA/c6N1cMGE9885PZ7ktMay7fqtWFVU/a+jdWvywGmWo/BwL/DzLwYaJYZaDbIkiRVenUNkyRpj1S98akHUrpmPVYddetvoE5pcVi9dMswXDfB89+hsfy5rlvN7Xt0dtk9jlLba9LOpDQzmKlfeT3grRHxtCkElgdxGqXRz6jOGHTDzNwQEQ+k1B2c6TR8DKUj9XGNzZ+bmV+bx7hgdh3Fuca3JSK+QMlGnHErSvORgQ4BnDro+aogVjNIPur1/31211sMSiONZu3YZvDmMRFxBvDWMWWDnQacUN1fRgks/2X2qck4BT+uslwHkpkXRcSl7K7huoySLd4tq/MYSjfnGVvonfHe63w/j4izgGsOs98YTez1oK7qSH09SjOx9ZRyAs0u081mLFejt1vS+buA8uHCJAz8GlA5C7hhbXm/8Q1FkpYWg4+SJHW6BZ0BvK/02O5WdHaaPqWtAfWxscu6fbusa0sze+V6fZr49HNMY/mwEcczL5m5MyLeSmdH48cCD4qIj1A6r34tMxdKV+szM/MtkzpZZv4oIh5Pqbs5oxl4/BTwL2M43ekj7lMPPt5oiH3PyczL5t7sCtehs8kSwJ0iYpSs3Ss1lmdd/5l5ekScxu7n3HJKZt4zIuK/KYHPb2WPxjwDeCO7g49QAkC/jIjPUgLBp2Tmr0c89ricPcI+G+h8TdyH7sHHZsbajzNzR5ft5vJ9phd8nNjrQUTcjZL5e29glDIZzedOXTOr95zM/P0I5xjWZZl50ZD7ND+U2qfrVpIkg4+SJDXUp1wnvYOPx9fuT6PeI5Q6WE3dpmyOXZX5dVBj9ZPGdPhp1XwEeAlwezprku1LmXb8aICI+AWl0+nXgC9n5jmTHuS0ZOa7qgBbt261ZwOPzMwcw6lG+Zk29xnmOvrzkOfq1tm4a1fdEfQa9yMoU4nrz7sjgGdUtx0R8X3KtflVSsDw4kFOmJmfj4jXAH9fW72CEpA8ASAizqvO/zXg1OzTlb4ll4ywz87G8vIe2zWDYb1Kbcyl22vykhER16JMfb/jPA/VrV7xjObflUnV3r1khH0Gvb4kaY9n8FGStEeJiJtRuln3Us+cuojSAKXbdvev3f8T8LAe2/0hMz857DgH1C0T6RiGnLI6ov1pr3Z0c8rdxGTm5RFxZ+AVlG7iq7psdp3q9iiAiPg28HbgXZm5fVJjnaLnU7pFN99oP2bQYNcAhslCnNHMaOuXXdXULYu4nzYD5F2v/8z8VUTclFKXrls9yxWUpio3B54KbK+mor82M78010kz82kR8XPgZczOxoQScL1/dSMizgbeTanHN2zwdhTjCGr3sl9jedQyC6Nct4tCRNyA0sF+HE1V+v3taD63LhnD+QbR5vUlSXs8g4+SpD3NfYAXDrjtgQxWgP7QPtudCrQVfPwppe5jvfHMXB1dx6VbUG5cukZxJ6XqQvz3EfFq4K+AE4Fj6Z3Vcovq9oyIeEhmfm8iA52CKBH2t9D9Z/F4hqyT18cogYBJXjdTuf4z83fAfSLiWMq1eS/g2j02X0kJUt47Ij5HyUrt2wQoM98WEe8HHkxpKHVbeteRPQJ4AfDUiHhiZk6qLl8bmvVzR/39tnldTE1VC/kDzA48/gD4KPAdSubxecDlwNZ6LdKIOJ7eswjmYlBQkpYAg4+SJC1Smbk9Ir5B5xS420TEqmG77o6gW6bTUZn585bPOzFVnbGXAy+PiPWUenvHAbehBGWaGWrXBr4cEbfNzB9NdLCT84/M7kQ744ER8ZTM/PcxnGeU2qXNemvjysLspnn9n5+Z3aZityIzT6fUuHxqRBxGKRNwa8p1eVNmB4fvAXwpIm6dmX2zPKvH3wG8owo63YRy3d+WUpLg4MYu+wDviYgVmXnyvL6x6WleK/uNeJxR91voHgYcXVveATxqiIDz3kOcq9lUaJgMZknSAtXWdClJkjQZzazKA4D7tX3SKrjZnGJ4rbbPOy2ZuSEzv5iZL8nMu1N+zicAn29sug+Dd1heVKpajy9rrD6rsfwvEXHLMZzu8DHs0+ZU4GbToUOqAPXEZeYfM/MjmfkPmXlLSn
Dwr4GfNTa9ISV4PMyxt2fmtzPz3zPzQZQs71tS6v41G7K8NiIWa6CoWavxqBGPM+p+C939G8uvHDLTtVnHsZ/mc2vJ/l2RpD2JwUdJ0h4lM1+UmdHtBvxPbdMz+mz31dp2p/Xarrod3/K39D6gmeX4hJbPOaPZcOL4CZ136jJza2Z+OjPvQWd3bIDbR8TVpzGutkTEwZRpl/VZM6dSaozWO1OvBP47Iubb+OjYMezzg3mOoZ+fAVsa6+7Q4vkGlpkXZeZ/ULp9f6rx8CPmeeysgpF/S8m4rgcg96WzY/Zi8t3G8lUj4irDHCAiVgM3HtuIFpYbN5bfPeT+txhi2+bv4vCIuOqQ55MkLTAGHyVJAiJiOXC72qpTemy3hpL5M2PUOlZjkZl/At7VWH37iPircZ2j6mzdzRcbyw+IiD2xpMurmJ05daNpDKQN1e//fcCVa6vPBx6SmZsoTZouqT12deDd0aMD04D+35BjXAPcrbH6tHmcv6+qLmizw/2D2zrfKKrmR89orL7GuDI0M/PrwEcaqxfldV/Vwmx2VX7YkIe5H73rYy52zan2A3ejr/623muIc30H2NRY9/Ah9pckLUAGHyVJKm4O1N+U9woqHkdng5epBh8rL2Z2d9Z/j4h5T1eLiH2A/+rx8EeAXbXlI4DHzPeci01mJrPfjC+lIMQLgbvUlncBD83MPwJk5m+oOn/X3Bt45jzOef2IGCaT8OF01nzcBXxmHucfxH83lh8SEddv+ZzD+k2XdeO8NpvHX8zX/fsay0+NiIFqj1Yfujx3/ENaMJrZ9fsNse9DKR9IDKQKmn+8sfrvBv1dSJIWJoOPkiQV9aYtu+icWt1ru+3A/7U2ogFl5rnAPzRW7wd8PSJGzkSKiFtRptTevcd5f06Zilv3rxFxk3mcc2qdrkfN2qyacjQDvefNf0TTFxF3A57XWP2izOzoap2ZHwde3djunyLi9vM4/eurqaxzjfFKzK5F+fkqKNqmk4Gza8vLgQ9FxH6jHrDX9T+PjOJmMHQnjZp688xWbh5/MV/3b6O8ps84DPiPKnNvLv8C3KCVUS0Mv28sDzS9vio/8doRzvdKOrtcX5nSAMn3rpK0SPkCLklSUQ8qnpGZl/TY7vja/W9l5ubWRjSEzHw75c1z3SHA1yLi2RGx16DHiohrRsQ7KIHVI+fY/PnApbXlvYH/jYgHDHq+6pyHRcQLmV2jbpKeEBGfjYh7DPkm9xXAlWrLGylTBxe1qs7a++j8f/HzzA70zXgWncH45cAHqnqRo7ghJZjX89qNiAOBz9E5LTT7jHFsqgytpzdWX58S9B8qEBURx0TE2ylZzN28OyLeHhHHDHHMdcwO/HwtM3c21t0gIn4YEY+p9hn0+P8PuE9j9ULIBB9JZv6BEvSqeyDwyYi4Wrd9IuLAiDgZeGq1qlkHdKn4cmP5ZRHR929DlQX8VUpzrqFk5o+BdzZWPwD48BDZqEdExJOHPbckqR17Yl0mSZI6RMQq4Da1Vaf02G4vFlC9xy6eAOxFZ1OJ9cDLgSdFxEeBzwLfAy6cCUJU2WXXpvwM7g/cmRI4mlNmnhURD6ZMcZ3ZZ3/Km8TTgP+gvAH9dWbuqs4XlGDRDYGbUrJojqMEub430nc+HsuAe1S3P0XEJyi/4zOAX1UdvgGIiEOA2wNPqr7Wvb2qhdi260bE4+Z5jK9k5pnNlVU23AfpDKr+Hnj4zO+xKTN3VNfC99nd3fYw4P0Rcbde+/XwLcpz7QTgRxHxT8AnM/OianyHUQJDz2N2Pbo3ZOZEMpIz8yPV2OrZoUcDZ0TExyglC76RmVdkBFaZdIdTmvUcR6lved3q4Wb26Iy1wEOAx0bEmcDHgG9Srs3f155by4BrUK7hpwHXbBznNT2Of0PKc/X1EfF5SpD5dOAn9Wu5qhd5c+CvKK8z9cD09+idMb5YvBS4J3Cz2rp7Ab+KiC9RmqFcSHmNuxElK3wmYPt7SimKp9T2rWfvta211wPgLcDj2f37PgT4bkS8DPjvzPwtXPG6cTPKVOu/BVZV25/C8A3JnkRpInXj2rr7AXeIiNcDnwZ+UH0IMJOBfn3Kc+r+wJ2AHwOvG/K8kqQWGHyUJKkEOdbWlnsFFW/N7jdT/babiszcGRGPBH4OvITOAOKVKW/mnjSzeURcXG2znv6zIZpdrZvn/XwVdHonnXUzb1XdAHZFxKXVeeY630JwEPDY6gZARGwGNlOCDb2y8b7L5Gq/1X++o3oU0C3Y8M+U633GDuDBmXlhv4Nl5rkR8TBKNuLM7/jOlLqRLxxiXE+i1FS8BiX79p0AEbGBcs2u7bHfl5jdZKVtL6CM6VnAzLTp5ZTg6AMBImIHJUN4DfOvi3hdOjusZ/Vz2U6pe7myx35vzMxPznHsvYATqxsAEbENuIxS67ZXs5o/UwLTkwy2jV1mbo+IuwNfoHwwMmMVJQjZq3HKRZRA+f0a6yeZCdna60Fm/jgiXkNneY/9KNPN/yUiNgFbKUHZZumAzwP/ypDBx8y8PCLuTcmGr3eyP4Da60n1dyUo1+bUynZIkvpb6P/4S5I0CfUp1zuBrw2w3VZK5tGCksXLKW+c+wVHg/Imbl96/z/wE+ABmXnHHo/Xz/sR4Bb07jC8jPLGtN/5dgE/mOtcLZorcLKWkgnYK/D438AdM/PysY5qwiLiRErWXN2zMvMbg+yfmV+kZJDVPa+qHzmoCylBy2YgZD29A48fA+5bdaKemOo59xxKBuNve2y2AjiQ/oHHzXQPBEP/azMoQccD6R543Ao8PzOf1OWxuY4NJfB2JXoHHs8AjqtqwC56VXbtHSkZc80p6t18i/L9n0Fn0yPo7AK/2D0DeEePx9ZR/p40g38fomQh7hjlhNVU+NtRPnzolTm9L+Xn3i3wOEy2tSSpRQYfJUmaXe/x0h7bHV+7f9qkgxzDyMwfZOadKFMk3wj8bsBdzwbeRHkzfYPM/OgQ5/x5Zh5HCRp9nM5akL1spmTGPB04PDOn2S37DcBtKTUcv0kJ2sxlM+UN9h0y88GZubHF8bUuIq5JaaRS94nM7DUduJeXULIQZywD3hsRVxn0AFXDmGOrY/25z6Y/AB6YmfefZuA3Mz9FaTz0KODrzO4Q3M0FlOntJwGHVrVbu3kYJbj5VuCnDDad93zKNX39zPynPuP+AXAU8I+U5+IlAxx7F2Uq7SOBm2bmLwfYZ9HIzA2Z+RTKNN4XUj5U+QMlu3QTJbv8XZRMyOMy8xfVrs0SABdPZsTty8xdmflY4EHM/SHRtykfXP3FfOsiZ+bmzHw0pTTAe+j/WgDl2vwW5Xoe5gMPSVKLYpHPjpAkSQOqGojcALg6ZcrcKkpzlIuBPwLfy8y53tgNc77llHpd16JkZe1PCchspLyRP5NSC3J7r2NMU1UL9HqUab9XpmR+LaeM/8+UINBPMnOQIKX6iIjmP6TXyMyza4+voGTW3pByLW2hXEPfW6iBr1qN2KtSxryeEqy+DDiHcv3/bpSpylXTjetRajoeTMk8S2AD5bn8I0qd0qEzv6qarNeqblejZJatr
sZ+KfAL4Id9PqTZY0XELyk/txk3rJqnLDkRcS3K9X0oJRt8I+W6/nZmntvieZdR/q5ch5KRux9wOeXv2C+BH/VpGCdJmhKDj5IkSZqquYKP0kJXdTj/UW3VRmDfUQLAkiQtNU67liRJkqT5eX5j+csGHiVJKgw+SpIkSRIQEatH2OepwF80Vr9pLAOSJGkJMPgoSZIkScXLI+JjEXGPqu5rTxFxREScDLym8dC3gS+0NUBJkhabFdMegCRJkiQtEMuBE6vbhoj4FqWW4/mUOo57A4dRmq3cvNq+bgPw0FEaCUmStFQZfJQkSZKk2dYDd6lugzgPuH9m/rq9IUmStPg47VqSJEmSirOArUPusx14F3CzzPzm+IckSdLiFs4IkCRJ0jRFRPMf0mtk5tnTGIsUEeuBuwG3Bm4EHA4cBKwFErgE+DPwQ+BrwCcz83dTGawkSYuAwUdJkiRJkiRJrXDatSRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa1YMe0BSJKk3iLiFOAOtVV3zMxTpjOa6YqIVcAxwDWBw4B1wE7gEuBi4EzgJ5m5Y1pjnISIyPpyZsYc258MPLK26lGZefL4RyZNX0QcD3ylturUzDx+KoPRghAR1wKuC1wN2AdYBWxk99+OnwC/zszsdQxJ0vwYfJSkJaRLkKGXncBllH+8fwl8G/hcZv5fa4OTRhARa4GHAX8B3A5YPccul0fEd4H/Bj6YmX9qeYiSFrkh/3ZuAC4FzgFOB04FPpOZ20c894uAF3Z56NOZecKIx2wG0e6amV8a5Vi1Y34CuG9j9ecy857zOW4bImIZcB/gwcA9gAMG2G1DRJwOfBL4SGae0+IQJWmP47RrSdozLQf2B64B3A14HvD1iPhBRNx7qiNrSUTcOCJeVLudNO0xqbeIWBkRTwd+B7wNuAtzBx4B9qIEKV8P/CEi3hURh7c3Ukl7kOXAfsDhwO2BpwIfA34fEU+PiOVjPNd9IuLWYzzeyCLiYOBeXR66W0RcZdLj6SciHgL8CvgE8FAGCzwCrKfMMng1cHZEfDUi7t7OKCVpz2PwUZJUdwzw6Yh47bQH0oIbU7JLZm4nTXMw6i0ijqRk4/4L/d84bgH+DGzt8fgK4K+AMyPi/411kJK028GU16tTI2L9GI/7ijEeaz4eQfcZc8sor7FTFxEHRcRngfdTPljtJSmzPi6lZLL2cjvgcxHxubENUpL2YE67lqSl7ZfAv3VZv4IS1LkJJaNs78bjT4mIyzLzBS2PT+oQETcDPs/soGMCXwI+U339bWZuqO13KOV6vitlqt2Va/uupv+bUUmq6/W3cybz8SjgzsChjcdvA3w8Iu6ambvGMI7bR8Q9MnPaAbCT+jz2KKYcJK0+sPoi3V/nzwA+DXwZ+DFwUWburPZbARwB3ILyv9D9KL/fuhu0MWZJ2tMYfJSkpe0PmfmWfhtExAGUNw5/03joeRHxscz8fmuj05z2pEYJEXFd4AuUkgB1XwWenpnf6bVvZp4HfBb4bEQ8A/hL4KWUN5Z7tMw8CTN9tYeoGnL1bcI0gEH+dq4CHgf8M7Cm9tCdgIcD757nGGa8PCI+P61mKBFxczoDcDPjmPkZXzsibjOtmtHVB09fBq7eeOhHwHMz81O99q2ak/2qur0/Ip5AyfJ8Nn5gJUlj5bRrSdrDZeZFmfm3lCljdUH3IvjS2EXEXpQaXc3A45spHb57Bh6bMnNHZr6Xkp30Gna/WZakscjMbZn5OkpdwaZnz+PQ59P5mnUT4EHzON58Paqx/GU6u4l322YiqsYyH2Z24PEjwC36BR67ycwtmfl2yt+O5wHbxjJQSZLBR0nSFZ4D/Kax7h5VUEhq20uB6zbWvTEznzDq9MXqjeTTgAcCm+c7QElqysyPUbKu664XEc2A2KB+BXygse6lY25mM5CIWAM8pLH6XdWt7i8iYt1kRtXhKZSp7nX/DfxFZm4Z9aCZuTUzXwbcitn/F0mSRmDwUZIEXDH96B2N1auBBdFtU0tX9Sb97xqrfwY8fRzHz8yPAv8xjmNJUhf/3WXdLedxvBcAO2rL12E62YUn0lkDcSPwUUq24Yba+vXAAyY2KiAi9qVkJ9adCzxuTPU2qcrO3GMcx5KkPZ01HyVJdd/osu7wUQ4UEYdQ3nxdg/LGZAvwg8z84gD7XplSAP5g4EBgE/An4BfA6dOqfTWXqvbUzSnjPojShflPlOYF350pcj9tVTbrbYDrAftS3kSeD/xfZv5+CkN6CrCqse6J88lcaRrlzegkr8OIuAbl2rkysBa4CPgpcFpmLripfxGxkpIVdAPKVPnLgQuAb2Xmr8Z0jqD8/K8DHFatPh84IzN/MI5zjFPV5fjWlN/hQZROun8Cfgd8c5zXc+O8y4GbAkdTrtUVlC7wH8nMP7Vxzi5jWEv53q9LCVZdDpwFfC0z/zzA/uuB46r996F0Iz4H+Epmbmpn1GP1oy7rDh71YJn5q4h4B/C3tdUvjIj3tnUd9dAMeH545vcRER9uPP4oxlfnchCPZnZjsmdk5sXjPMl8r78qI/Q4dr8uLKe8Lvye8je3laz8iDiI8nf+msBelL8pvwO+mpmXjfE8AdyI8jp9EOV/iospfw++nZm/G9e5JC1ymenNmzdv3pbIDTiZUitq5nbKkPsf1dg/Kf/Mz3WeF9UeuytwCrCry7F6jgdYCTwJ+GGX/eq3C4C3AFcb4Ps5fo5jzXU7YoBz7AU8jdJRs9v3PHP7M/C2QcbdOP4pjeMcP8f2J/X6mQOHUGoobuozztOAO03wmp15I1Yfw0+n+Bwa+3U4x/nuDXy7z3kuBV4N7Ffbp2ObAc5xcmOfk4Z83pxde2w9pUHVxX3G/FPggfP4mewFvAj4Q59znAU8AVjWY8ynjHr+EcZ7L0oNvG19xrsZ+CRwyyGPfUSv3zflTf4rgQt7nPP4MX1//a6HQ6rnweYeY9hKyTq+Uo9jX4MSsLq8x/6XA68F9p3HeOe8Fro8R4a6foBrdxn7cwfc90WN/b5erb9yl5/r0wY8ZnMsdxnh935VSgC96zUF3KHx2C7gmhN83jVfoy8AVk3q/AOM736U+phb+7wubAE+M8LrQvN6fVHtsZtQuns3f3czt+2UTN1rz/P7uxZltsz5fb6/mb8HTwBWTvt34s2bt+nenHYtSarr1iE0B9oxYkVEvJnSrfgOPY7Va99bAmcCrwduOMfmB1GyQX4REc8a9BxtiIgTKfW5Xk355L/f93wA8NeUcTenGLcuIu5EeRPwOEpmXS+3BP43Ip47kYGVjJArNdY1p/9PxCSvw4hYHRHvo7xJvHmfTfehBLd/FBE3GvY841Sd/0fAs+icitl0FPChiHhz1RBimHMcA/yE0uzqsD6bXgN4I/CViGhmP01ERBwcEV+mBA+OpwSue9kLOAE4LSLeX2UKzufct6A8n59JycqduIi4QzWGv6V8f92sAh4DfCcirtXY/4GUANIj6OwWXbeGkhn9zSqzfKHap8u6
eWXMZeYfgDc0Vj+7yhKdhJPoLNF1DnBqbfmrdNZDDOCR7Q8LIuJwZr9GvzcXQJZ4RBwVEd+lTE+/I7Oz+utWUz68OC0i/rPqoD6fc/8D8B3Kh1q9XntXUBoYfT8i7jbCOVZFxOsppVEezdwZvkdRXqt/EhHXH/Z8kpYOg4/SFETEdSLiARHxxIh4dkQ8NiJOiIjrD/tGTRqzbm/uLhxw37dRAlt1OykZUj2nvEbECZTsgGv02OQSOmtfzVgDvCIi/mMaz5sq4PRRSnZK0y7KuLtNp1oDvC4imt3FW1MFHv+H2VPULqFkXnTzTxFxUovDmnGHLutOmcB5O0zyOqzeYH6E7l1yoWR7bWysuyrwpYg4cpBzjFtE3ICS3Xd446HL6B1keRyldt2g57gx/X8Hl1KyiOpuT7m2ewWvWhER16ZkCd+xxyYbKb/Hbh5CCZo2g+6DnvuGwBeZ/dqziXkGvIYYwy2Y/ZqyizK1s9vz5AjgMzNB1yrw+AFg79o2/f5eHAV8bAH/j9Ttg4HfdFk3rFdSrvsZVwL+YQzHHcRJjeV3Z+YVH0ZW95vTrB9ZTcNtW7e/G1+dwHn7qoJ536SUQehmI52/z7pHAV8c9YOJ6gPDf6XMJpixk97PyXXAJyLiekOc40DgS5TZAd3Kt23rc75rA9+IiGaDIEl7iIX6B1wLUEQsi4ijI+KREfH6iPhmRGyOiKzdjp/2OBeqiFgZEX8fET+iZNZ8mPKJ9suBt1OmY/0E+HNE/JcdhjUl3ZrLnDPAfg9kd+2nDcCLKbXgVmXmAZTAwE1ovFGJiKMob0Cb/2x/ilLkfU1m7k/JHDiK0hG5+Yb+McCze4zrF8Djq1vzTdIva4/1unWtVRYRT6RMPa2/yfoD8Hzg2Or73j8z11GyAv4K+HHjME+PiElkiRxKmWK1mvJG5D8pb9xWV2PcCziS8rNtBiJfExH7tzy+YxvLWynZUBMzgeuw6RWUzJS631OCdYdm5trMXE/JaHsk5W8GlMDDewc8xzjtBXyMUtuR6v49gHWZuW9m7g1chRIUuaSx73Mi4jpznaCqi/ZRZmfxfZGSMbguM/fLzDXA1YEnA+dV29ySkik5EdXf548zO0j6Y0oW3wGZuT4z11KyNx9PqbVWdwvgfSMGav6L3Zl2pwL3p0xL3rv6XRxICRz9cYRjD2Iv4IOU58sOyrTrW1Je9w6kvN7fjhKsrrsO8Mwq2HEyJUiymRJgO4YyLXPm78U9gWZdz1tRMq0Wogc3lndSgtPzkpkXUQJKdU8bNXA9qIi4PeXvQl23eo7vonN2xOHAndoaV82Nu6z77gTO21NEHEv5X37f2urNlOfH8ZTXsPWZuR8l8HcPStZ03e0pWYLDuhvl7xKU4OaL2P2cOpDyt+vmlNeOujXAWwc5QVXn95OU53bdqcDDgKtk5ura+W5EeY9Tb0y0L/Dhqia4pD3NtOd9e1scN0qGxkb61/QYW32hpXajfAL64wF+fvVb1/pI3rz1uzGPulWUT7HPauy/BdhrgPPUa/tcdcDzLaPUSazvvxN49Bz7XZsSEG3WMLrZHPudNOrPpnGcmzG7htN7gfUD/Hzf0thvE3DYHPudMszrbJfvc+b2J+A2c+x7x+pnWd/vyS1fsz9onO+MCT9nJn0d3orZtbi+0O/6oQSO39fj95oDfI/N5+tJc2x/fI9zbQLuN8e+N6C8+a3v928DjPHfu5yvb307yhvZr/UY6yktXjOv73K+t9GnphmlVubnu+z31DnOdUSv3zvwDxN6jvS6Hi4GbttnvxWU4Ep9nwuBr1f3zwau12f/vYHvN/b//gjjnfNa6PIcGfj6oXz41vzZfHqI/V/U2PfrXX4Ozbp6fZ9TXcYzVM1H4J39xtTY9tTGtu+dwDX5iea1OInnQp/x7Mvs/5++Bxw5wL6PogTx6/ved8jrtX7Ouf6neH6X/Y4ZYJyvbuyzGXjoAPtdi/IBWn3fj07z9+XNm7fp3Mx81KBuSvmUTkOKiOMo08iOrq3+HSUI8Q+Ufzr+jvLG6zT6TE+VWvYyZmfyfDYze00dbLoUuGsO3i35RGZPVXtmZv5nv50y85fAXeic0rwCmFSNwlfRWcPpg8AjMnNDj+0ByMwdlAyoT9dWr6XUM2vbDkrQ6P/6bZSZX6E0pKl7YGujKpp1/ebsjDtmJzLZ6/D5dM48+TlwYr/rJzO3UjIgpz2t8DGZ+bF+G2TmjymZz3V9r6GqZuPfNFb/a2b+2xznupRSL+3sftuNU1V3sDnWTwF/m5nbe+1X/X5PZHZW77MiYvUIQ/nXzHz1CPuN00Mz8+u9Hqxe855A5/81B1I68G4FTsjMn/fZfyOzXx9vPK3SA01V7bunMDsbeRuDZ0HPqfo5vKyx+gkRcbVxnaMuIvZm9nP2XX12ObmxfP+I2Hesg5qtWXLg4pbPN5en0Pn/05mUgO+v59oxM9/J7OvlOSOM4ffA3TNzroznlzE7q3iu1+hr0vlcTOBBmfn+uQaVmb+ivE7XO2yfWM04kLQHMfioUWylFDN+C9OZ/rVoVMXVP8/u6VGXUQqzH5GZj8/Mf8vMkzPzDZn51Mw8jvIP1XMp/7xKrYuI/SPiLcAzGg8ls4MI/bw0M88dYvvmm8ofAK8ZZMcq8PNPjdX3jYheteLGoqpzVq/xdjHwxMzMQfavtnsanW/G/3oCdcze2S9I0PC2xvKxLY+v2TyhVz2stkzsOoyIIyhT7er+LjO71QZtnmsmeD2tD6i+mJkfGHDbd9JZ8+tqc0yzO4nOmo1/ZMBp1FVQ72kDjmscHk/nhw+XA08Y5DWg+iCnWRf3EOAvhxzDnxiilmZLPpmZn51ro8w8h5Lp2PTmzPzRAPt/FfhtY/XNBhvivFw5Ih7X5fbEiHhORLynGtdrKZnJM3ZSgvRzfm9DegudJVBW016pgQfRWYtzC6V0Ry8forPW6F7MnoY+bns3li9p+Xw9VTUam03kHp+ZwwREX0PJnJxxy6oG7jCemZlz1ujOzF2UDvR1/ZqeATydzlqS78nM5pTxfuf8NSXJYkZQ3g9J2oMYfNSg3k35pP+mlKlht8jMxwP/O91hLVxVHad3sPuN9Qbgbpn5tuoPf1eZeX5mvjwzL+u1jTSEfm+gnh8RH6W8ger2T+DLMvOMAc+zndnZDz1FxD7AbRurX5+ZOwc9BiVDr16jcBmzAzvj9rDG8vsyc6hMvSpgVa9NdQClNlObmtmMPVWZa/XXn3VAKxk2lWbW10QaZsBUrsP70Pm/15mZ+aVBT5SZP6Vk0k/DMNfQxZROqHX9mhrcvbF88iAB2ZpPAsN88DEf92osf2SIbG8y85vAt+Y45lzeM0RGelvePsS23+myrhn86KdZy2/gBhnzcG3KNd+8vYGSOfZwSuC47kxK9v/YP5jP0sX5RY3VJ0XEdcd9LnbXb57x8SrLuKsqM/Ojcxxj3Ob9dyMift+oW9/vdkqfQ92dUo93xo+rWQQDqz5c+nBj9fFDHOIi+geIm77RWO75nKo+fGx+QPK6Ic41o1lv8vgRjiFpETP4qIFk5gsy8+2
ZeXq/aUXjEMVNI+IREfEPEfH06v7Rc++9oDycUjh6xjMzs/mGQ2pbvzdQLwHux+wMAoB/z8znD3GeHw4ZhLsVnX+DktlvXvrKzEsozSjq2u6ieIfG8udGPM73GsvHjXicQVzC7ClWc/lNY3m/sYyku2b34kmW+Jj0ddj8PX98mHNVhhrfGJ065PZnNZb367ZR9UHdLRqr/2eYE1XB4s8Ps88oqgynGzdWf2SEQzWDBMO+bg0V2GhB0j2bsZdm5uJFlPrAo+6/3xD7TkJSyikcPWzQaUjvoTOov5zdTUbGopqx02wo0m/Kda9tbtXytNrm7KBploZaCP8XfL0KYA5qoNfnyjHsbjYGcGFmNsc6p8z8GZ2N2m5YTfGXtIdYMe0BSDMiYj3wTOCxzP40eWabXwIvzMzmp2cL0ZNq93/FgN3kpCn7EfDsYabT1PYbxg0by78ecorSjO9SOuHOaC2DsAo8NMd98xHrbjWn5TbrHo7Tb/tlW/fQrD+4T9etxmMjnVNu264VVjfp67BZW3LoN3Aj7jNfl2XpujuMQa+hw+h845uUBkDD+v4I+wzrKGb/79wtq28uzUy+q0TEAUP8jMc9pXdYl1ZB90E1s9J+O2ipisrGxnKbr0ejCEqJkjXA89o6SWbujIjn0RnwfmBEHJuZp4/pNCc1lv/I7A9XuvkyJUh89caxnjmWUc3WvCYm+Xej6VaN5atERLO8wiCawdph/i84e8hzDfM3vvn9bRzx+4MSNN6rur8MOJjZv0tJS5TBRy0IEXErSgZIv5pQULK43h8R9wMe1nYW5qgi4hg6MzneMcKbf6lNOylTay8Bfgl8G/jcXA1J+hi2SciBjeVmpt2gmsXcm8cdp4OZPWNgXDW32hz3JSPs05x2vLzbRhFxR2CYaX+fzMw/NNb9kc4pawcMcbz5mvR12Fx/9gjnGmWf+bpkhH0GuobozKgB2FBN4xzWXE0WxqH5+9s+ZJ3bGd2aUBxIyQgcxKSbMjUNWxameS3Md/9e19I4nZqZx9dXVFm6ewNHAnej1IudaXyyDHhuRKzKzGb95LHJzI9GxHfYXaMvgJczhpIj1fTav2qsfu8gZSgyM6s6mPWGW4+IiOcMWcZiUH+glIKaMcrfjWfRfeYHlAZfzaBbL4c2lh9S3eZrmP8LLhnmwFUgu76q32zI5vd3BEOU4ZjDgczOwpS0RBl81NRVb14/Ten6OuPMat2vKUXrrwv8Bbvrjj2Ikh3RdkHrUd2tsTzUFDJpjGa9gWrJsMGCZsBh1CYjzf3aDFy1GSBcO/cmIxsmw2hYj6xug/o55U1j3a/pzEA8KiJWTujDpUlfh83zjVLbd9INeaDda2i/xnLfrvF9TKJOclvXCwzx2jVicHac5ns9tHk9tabK1txAycw9IyLeTCmDcJfaZv8YEd/NzGHq7w3rOXRmI949Iu6QmcOWRmi6C7Pr+w4y5XrGyXQGHw+jBEWHnUkxiFnThiPiKsN8GNCvNmeVFDFo8LGt/w2G+b+gzefUYv3fR9ICY81HTVVEHEwpQDzzx2cL8BjgqMx8ema+uao1+XRKALI+dfkvIuIRkx3xwOpZjxuAHwNExLER8YaI+ElEXBYRGyPiNxHxsYj4m4jYq/vhpCUnGsvj+se5zX/AV829yciaP489SXO64GpmT4duy7Svw0UZhBmzZs3PUZ9nbT4/Z7R1vYz7WJqAqtP6/Zldv/JNEXGlLruM67xfYnbjqVeM4dCP7rLux4M2ZaHMomhqq/HMGV3WzdWxuS1tvfYslP8L/N9H0lgYfNS0vZLdU613AffLzP/sVgsoMy/PzMfRWevmpdU0kYXmJrX7vwTWRMQbKXWenghcn9IFex1l+sKJlMDqWRFx/4mOVJqO5vTC/UY8TrPO0yj1+gbVnOqYwLrMjDHcTmpx3Atdt2ydO07o3JO+DpvrR6lTNs3aZm2Y9TOJxnzAAe03hrHMpXm9jPq76LZfm69dakkVgHwU5X/YGQcy5kYwXTy7sXxcRJzQdcsBRMR+wP+b14i6OyEi2sic+2qXdbfvsm4Smv8b3HNM/xccMY1vpovm9/fBMX1/kZmnTOMbkjQdCzFooz1ERBwKPKy26j8yc5AOcU8GZqbjHQ7ca9xjG4ODavcvAD4EPIHdn/BtA37P7BothwIfjointD1Aacqa/8weMeJxrjnHccfpT43l6HL+PUpmnjSGNxrfYPbv7THtjx66nPeIEY8z6HU4jvONss9Cdj6dWX+rGO151WZn3RnN39+qiLhy1y376/b9TbuOo0aUmd+mdKKue2zVObrNc368sfpl8/hA/qF0Nv4al1V0/q8/Fpl5FrMzTh8eEZPIgG5q/m/Q2u99Spb69ydpQgw+apoeSGcq/2sG2alqVvCl2qq7jnNQ81VlbKyvrbozuwOkZwL3A/bJzKtl5v7A9YB31g8B/FtE3HkS45WmpNmt9VpV5sWwbtZY/sFow5lb1QX5nMbq49s6354iM3cw+437UVU94LZN+jpsrr9p1636G2WfBauqX/jzxupBa63Nd59h/YxSh7qu+bsfRHOf34/QTVwLywspHyzPWAE8v+VzPo/OjMsbMnqjk+b06C8Ajx/x9ok5jj0u72wsH0SpCT9p328sHz+FMbSp+f3daMS/k5L2cDac0TTdrnb/rMxsvvno59vAPav7t+y1UURcdZSBDejSarpN0zo6A/srq6+nA3fKzI5C85l5JvDoiPgZ8Kpq9TLgtRFxTLcp6NIScBrlTdPMcyUogfnmm4meImJfZn/48I0+uzSDBqN0TP0i8Nja8oOBN4xwHHX6d0p2eP0DqTdGxLGZuWUcJ4iIZZm5q7F60tfhNykZRjNOpHRcHcZSLM3xf3RmLj4MeN+gO1fZh8ePeUyzZObmiDiDzuDhA4BPDnmov2gs93vd0iKQmedExLuAv66tflhEvDQzf9XSOX8SEe+ls0P1SyJiqGY3EXEDZgfE/zkzm3UlBz3ed+icwn3jiLhxZp4xyvH6eAelwc1+tXWviojPTjiY/0U6G6/dIyL2bf6/v4h9A9hEeX8DJX7wAMrPX5IGZuajpulGtfs/GXLf82v3+wUYf9fi7Yk9znl5l3W7gIf3+0ckM/+FzgLiN6Czg6K0ZGTmZcDXGqufNOSUsccB9SZNu4B+pRuaHxaMUq+t+abuthFx9xGOo5rMPBt4Y2P1UcC/jOP4EXE/OoPGM+ed9HX4aTozla4bEQO/zkfE9YE7DTG2xaIZaLx71W12UC9gtA8TRtHs3PvAYaZeR8QtmZ2l2UY3YE3ey9ldFgjKNdl29mMz4/KadHmtm0MzM/E84JRRB5SZ32N285mxZz9WsxFe1lh9ZeBtE64H/1mg3oF+HcN/qLRgZeY2ZmezPj8iVk9jPJIWL4OPmqZ6AeoTBu2mV3XUe1Nt3/0nPO6+MnMnpWt33Rcy82cD7P7axvKCmlIujdnrGsvHUmq6zikijmT2m7pPZOZv+uz2x8bytYatD5WZX2R2ltI7I+LqwxynbsTmGkvR85j9hvVJEfHGUd9IRsSaiHg1pVHZ2h
6bTew6rIKszcDk6yKi19jq51oBvJkl+L9bVQu0/iHkMuDkiDh4rn2rwPLftDS0bt5CZ7BnLSVLd87ncUSsqfavOx/4wPiGp2mpnt/vbqx+WMu1H88G3tZYPXDAs3pdeXhj9Qe7ZIkP678ayw9rqR7jaygZ7HUPAD4SEXt12X7sqizL5t+Rf4yIe3bbfhAL8P+ClwI7a8uHM8/MxwX4PUpq2ZL7B1aLyn5jOs6cb9qm4LLG8lcG3O9UOgvvHzue4UgL0seZXQPvXyPiEf12qgI+X2L3FCAoU6qbGRBNP6Jz6vVeDJ8hAvAPdGa3HAb8X0QM1WkzIq5ZBcbePsIYlpzM3EyZqtfMEH8C8JWIGLi2XkSsiIiHURoSPI3dzb66+TiTvQ7/ic7sx6OAj0fE3n3OtQp4F9Pr5joJT6bz7991ga9GxK27bVz9jp8OfJDy+x3L9Py5ZOZ5zA72nEgJQPYsZ1T9fj8K3Ljx0CuqzCItDS+n8+/McsoHK236J8q02BmHDbHvvYFmkL8ZOBxF8xgHAiN34+6l+sD/AZQmjnUnAt+OiPsOe8yIuBFwkyF3+1fgt7Xl5cBHI+LxQ557v4j4e+BbQ56/VVVprDc1Vj8sIj4aEQcMepyIWBYR94iI/2F3+SxJewiDj5qmzbX7FwO/nsetqyG7sA57e2Wf7605pt923Wr2eC+rfhYzDuq1rbTYVZkVD6HztWA58O6I+FhE3GVmWk8U142IFwM/ZHbH3xdWU736ne9y4PON1W+MiC9ExIsj4kkR8bjGbX2X45xGKapfd1Xg1Ij4fEQ8NCIOr3+qX/3DfdWIuFdEvCgiTqe8TjyN8X0Qs+hVGeJ3p/N1EErQ7dvVz/fJEXH9ZrAuIg6u3tS8GjgbeC9wjQHOOenr8JvA6xur7wr8NCL+pp7tFxH7V0HQH7C7VmQzy2dJqOrLNX8u16UE9r8dEa+MiKdExLMi4j8o5U/+hVJXeQfwkuYhWxzuM5jdaffxwHer5/9+Mysj4pCI+BtKZmfzzfYXmJ0xpUWs6sL83sbqh7ec/Xg+pW7uKJrToX+TmfMOfFXBqjPmONdYVI0o78TshnA3AD4REWdExEsj4viIuFJEdJRoiIgDIuLWEfEPEXFqNe5jhhzDxcB96QwCrwHeFBE/joi/i4gbdjn3gRFxh4h4akR8AbgA+DfK9PmF5mnMTqa4H/CbiHhtRNw5IvapP1jNPjgmIh5WvW7/kTJN/Z4Yh5D2ODac0TRdCMz8kfpQZv7tNAczZj8BjqstD5ORUd92zXiGIy1MmfmziPhLSvZSfYrUidWNiLiEkl22ku7eAfT7MKDuZZTgVv3v313pXeLgc8yuFUlmvqOa0vVvjXHdrboB7IyIS6vH96Z/9p0qmfmtKPX+PkTnG8Cg8+dLRGyh1Nram/6vl5uAX/Q556Svw2cC16NcizOuBrwVeGtEbKZMcWsGvy+kTJFspYHFAvD3lN/loxvrb17dutlFmXZ9dmN9a5mQmXl5RJxI+TCjHuC+EVX9yojYQAli95qd8R3gYTaVW5JeBjyC3XVIZ7IfT2rxnK+iBMAHLkVUfdBxr8bqcWQ91o9149ryPSLisMxslkCZt8z8ZZR6qu9h9t/zG1W3mQzUrF7Pg/J6M9f74a9QZjzMNYYfRMS9KLWhD6k9dDS7P2TIiLiM8rq1D5OrVTtvmbkjIu4PvJ/OD1L2AZ5S3Yb5uyxpD+MnDpqmenfro6c2inac0VgeaEpClSlV/8fxz+MakLRQZeangDsCveo17kf3gM8W4NmZ+dhB61NVWWePYHZphKFl5huAOzA7A2rGcspzfz29A4/bGb7h1pKXmb+gBJueDVzSZ9M1wJXo/QZnK6VO4rUy83/mOOckr8OtlKBmr660a5kdePw9cNfM7Jntv9hVP7/HUhq6XTLALn8ATsjMd7L7w8wZg+w/ssz8JeVDxl5lVdbTO/D4X8DxmXlhG2PTdFXdrd/fWN129uOlwD8PudvDmf2aNs7g4wfozEBeTvn724oqA/TulO7fZ/fZdOZ/7f3oHXhM4OvA/8vMO2Xm9wccw1cpU7Z7/b0JSrO7/ekfeBzofJOWmZdQpuo/h84mO3Vz/V2G0tRo7EFoSQubwUdNU/0f9ltFxJWmNpLx+1Rj+cYD7nddOrNuzhrLaKQFrprmdT1K3bcfz7H5nyg1164zR/mDXuf6AHAk8CTgY5SMuEvorOM46LG+SZnadX/gi3RO3e3lUkrnyCcAV87MFw573j1BZm6rfr9Xo3SU/jKdjT562Uzp1Po44NDMfEJVp2+Qc07yOtySmQ+m1Ln8bp9NL6M0VbhhZp4x7HkWmyzeBFyL8hz5AiWQsIUSTD6H8vx5DHBkLajcrFvXnLrfxljPz8w7AfehXHP9XkMup3Q7Py4zH1rVONXS9U90NuiYRO3H11EC8oNqToP+SWbO9bo3sMz8LbMbtLUy9bp2zszM91BeP+5HCYAO+lqwkRJwfAHlteV2mfnJEcbwx8y8N3BTSib0nwbYbRvw1erc183MBdtwsvoZvwK4OmW8P2GwMhe/pDTcuidw1blKlEhaesLZHpqPiDgJeGdt1R2zdK0cZN8jKH+IZj51fFVmPnOc45umiPgWcItq8VzgiMzc0WcXIuKFwItqqx6Tmf/ZzgilhSsirkJ5/hxMKVS/ifIP/JnA6Qt1qmKUxiA3o3SCPJCSWbGFMnX7d5Tx/2bQDDl1qmovHkMJHh9KmQa9k/Lm8mJKRv1PsjQhGMf5JnYdRsQ1qnNdmfIh1MWUrNpvpg1J5hQRb6ezgdTfVdnJkxzDeuA2lN/hQZRr80+Uus/fzMyJNMWRtFs1q+jalA/4r0bJSl5JCTZeQnmt/SXw8zb+Nlfnv351O6C67aL8X3A+5QPQX1QZ8YtSlUBS/1u5ht0/319RfrbO5pL2cAYfNS/zCT5W+7+b3VMwdgD3zswvDLF/ACsX4huziHgQnVPqnpmZr+qz/dUo3Xj3rVZdRglYtp69IUnSYlUF/X9DCfrNuGVmfntKQ5IkSVKN0641bc9gd82PFcCnqm5zfQsUR8RhEfF3lCyXY1se40gy80PAN2urXh4RXZvqVHWAvsjuwCPAqw08SpI0p8fQGXi8kNm1lyVJkjQlZj5qIFV3s25Ze+vprLP0B0pdo6ZnZOZHexz7OEpH2Xqx+AspXSTPAC6i1MrZD7gOJdh4E3Y3cDguM08b8FuZqGpq+Tcp0wNnfJ9Sr+p3lE5wt6LUi1td2+Z/gbuPa+qgJEkLXUSsGnYmQ/U/xP/SWS/5lZn57LEOTpIkSSMz+KiBdJlePaxHZebJfY5/FPBxSnBxWLfIzO+MOK7WRcSNKN/bEQPu8lHgrzJzU1tjkiRpoYmIE4HnAm8APtkv+z8i9qF0xX4RsKr20KXA0Zl5bnsjlSRJ0jBWzL2J1L7M/FlE3AB4NKXL6PXn2OWnwGeB9y707p+Z+YOIuCHlDdIjgV5dvX8MvAz44EJtpiFJUstuBpwM7IiI7wI/pHS4vowyQ+BAygyI21IaDjX9jYFHSZKkhcXMRy1IVYfRW
wGHAPsD2yjd6H4N/Dgz/zTF4Y0sIlZQOmFek/K9baV0uvtmZv5mmmOTJGmaqszHj424+3ZKh+u3jm9EkiRJGgeDj5Wqa/KRwA2Aq1HqD26m1Bv8AfCjSdffi4hlwK2rcR1GmUp0LvA1G5FIkqSlpKrf+ClKduMwvgI8Z6HWf5YkSdrT7dHBx4hYD5wA3Be4E3BQn80vptQ8/NfM/GOf7cYxrhXAM4En0Nm9ccY2yj/nT8/Ms9sciyRJ0qRU/wPdHrgdcFPgGpT/hdZRygVdSvlg+FfA14DPZebp0xmtJEmSBrHHBh+rwOMFwJohd70IeGxmjjotqK+IOAT4NKXm0VwuozQm+UQbY5EkSZIkSZLmY08OPu5HyWasOws4FTgTuJASmLwh8AA6m4TsBB407gBkROxFmTp0y9rqc4H3UmodHgjck5IRMGMLcKfM/OY4xyJJkiRJkiTNl8HHkj34TuA/M/OHPbZdC7wW+Ova6ouB62TmhWMc078AT6+t+jDw8Mzc2tjuoZROkCurVb+rxrJlTOM4D1hbHVeSJEmSJEl7rqsBmzPz0FF23pODj3sDzwX+JTMvGnCf9wEPra16YWa+ZEzjuSrwS3ZPA/8hcLPM3N5j+2cBr6itenpmvnpMY7ls9erV64888shxHE6SJEmSJEmL1K9//Wu2bt26ITP3GWX/PTb4OIqIuDLweyCqVd/JzFuM6dgvA55TW3WPzPx8n+1XAGcDV6lW/T4zrzamsfzk+te//vV/8pOfjONwkiRJkiRJWqSOPvpofvrTn/40M48eZf9l4x7QUpaZfwB+Vls1ztTA+9XunwN8YY6x7KBMF59x1YgYpEmNJEmSJEmSNBEGH4e3sXZ/3TgOGBHXAI6qrfpSDpaS+sXG8n3GMR5JkiRJkiRpHAw+Du+I2v3zxnTMGzWWTxtwv28DO2rLx4xnOJIkSZIkSdL8GXwcQkTcFji4tuqbYzr0UY3lXw2yU9Xd+g+1Vdcf03gkSZIkSZKkeTP4OJxnNJb/e0zHvWZj+bdD7FvftnkcSZIkSZIkaWoMPg4oIh4CnFBbdQbwiTEdvtmq/KIh9r24dn9lRKwew3gkSZIkSZKkeVsx7QEsBhFxNPC22qodwF9n5q4xnWLvxvKWIfa9vMuxtg6yY0T8pMdD4+ziLUmSJEmSpD2UmY9ziIjDgM/QGSB8VmZ+d4ynWdNY3jbEvs1A417zHIskSZIkSZI0FmY+9hERBwCfBw6vrX5bZr56zKdqZjqu6rKul+Y062YmZE+ZeXS39VVGpM1rJEmSJEmSNC9mPvYQEfsAnwNuWFv9PuDxLZxuY2O5mQnZTzPTsXksSZIkSZIkaSoMPnYREXsDnwVuXlv9YeCRY6zzWHdZY3n/Ifbdr3Z/e2YOVO9RkiRJkiRJapvBx4aIWEup8Xjr2upPAg/NzJ0tnfY3jeWrD7FvfUr4WWMYiyRJkiRJkjQWBh9rImIv4FPA7WurPws8KDO3t3jqnzaWrzXIThGxBrhyn+NIkiRJkiRJU2PwsRIRq4GPA3eqrf4ScP/MHKb79Ch+0Fg+bsD9bkFn06AfjWc4kiRJkiRJ0vwZfAQiYhXwEeButdVfAe6bmYN2nR5ZZv4G+Hlt1V0iIgbY9a6N5U+Pb1SSJEmSJEnS/OzxwceIWAF8ALh3bfXXgBMy8/IJDuVjtfuH0xkInaUa96Nqq84FvtvCuCRJkiRJkqSR7NHBx4hYDrwXuF9t9TeAe2Xmpnke+4iIyNrtlDl2eTNQ71T9qohY2Wf7pwNXqS2/NjNzxOFKkiRJkiRJY7fHBh+rac3vAB5cW30acI/M3Djp8WTm74A31lYdA7yvqkXZISIeAry4tupc4A3tjlCSJEmSJEkazoq5N1mybgs8srHu6sD3Byu3eIU7ZOa5YxrT8ymdtm9WLT8IuHVEvAc4C9gfuBdwh9o+W4G/nERtSkmSJEmSJGkYe3LwcXmXdVce4Tj9pkYPJTM3R8QJwGeAY6vVVwGe1WOXDcAjM/Pr4xqDJEmSJEmSNC577LTrhSozzwNuBbwAOK/HZtsoDWpulJkf67GNJEmSJEmSNFV7bOZjZp4CDDW/esjjnz3q8TNzO/DSiHg5cGvgWsAhlEzH3wNfy8yLxjRUSZIkSZIkqRV7bPBxMcjMncDXqpskSZIkSZK0qDjtWpIkSZIkSVIrDD5KkiRJkiRJaoXTrqUWvf9bvx3r8R56y6uP9XiSJEmSJEltMvNRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLg
oyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRV7fPAxIpZFxNER8ciIeH1EfDMiNkdE1m7HtzyG4xvnG+Z2szbHJkmSJEmSJI1qxbQHME0R8RHg7sC6aY9FkiRJkiRJWmr26OAjcFMWZuDxHGDHgNtuaXMgkiRJkiRJ0qj29OBj3Vbgh8D3gL2Bh09xLMdn5tlTPL8kSZIkSZI0b3t68PHdwO8oAccfZeZ2gIg4iekGHyVJkiRJkqRFb48OPmbmC6Y9BkmSJEmSJGmp2uO7XUuSJEmSJElqh8FHSZIkSZIkSa0w+ChJkiRJkiSpFXt0zccF7OURcX3gcGAdcAlwHvBN4PPAJzJz5/SGJ0mSJEmSJM3N4OPC9JDG8kHV7YbA3wBnRcTTMvMTEx+ZJEmSJEmSNCCDjwvXxcBllMzHA+icIn9N4OMR8fLMfO6oJ4iIn/R46MhRjylJkiRJkiTNsObjwvFn4PXAPYADM/OAzDwiMw+iBB/vD/xfY5/nRMRTJjxOSZIkSZIkaSBmPi4M3wOumplbuj2YmZcCH4uIjwPPBV5ae/ifI+Kjmfm7YU+amUd3W19lRF5/2ONJkiRJkiRJdWY+LgCZuaFX4LGxXWbmPwFvqa1eDTyjtcFJkiRJkiRJIzL4uDg9D7i8tnzCtAYiSZIkSZIk9WLwcRHKzD8Dp9ZWHR4Rh01rPJIkSZIkSVI3Bh8XrzMbywdPZRSSJEmSJElSDwYfF6/LG8trpzIKSZIkSZIkqQeDj4vXIY3lC6cyCkmSJEmSJKkHg4+L1+1q97cD505rIJIkSZIkSVI3Bh8XoYi4J3Ct2qr/y8zN0xqPJEmSJEmS1I3BxxZExBERkbXbKX223WvIYx8GvLWx+uThRylJkiRJkiS1y+Dj9D04Ik6NiPtGxKp+G0bEXYBvAVerrf4B8J42ByhJkiRJkiSNYsW0BzBNEXF/4FVdHlrfWH5fRDS7SwM8IzM/Ooah3L66XRIR/wf8EPgjsIHSxfoawF2BGzX2Ow84MTN3jWEMkiRJkiRJ0ljt0cFHYB/gyAG2u3Kf/cdpP+De1W0upwEPz8yzxzwGSZIkSZIkaSycdj193wXeCfwMyDm2TeAbwMOB22bmr1semyRJkiRJkjSyPTrzMTNPpoVmLVU2Ygy47Y+BRwNExH7ATYCrA1cC9gK2ApcAZwPfzsxLxz1eSZIkSZIkqQ17dPBxocnMS4CvTHsckiRJkiRJ0jg47VqSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkot+945F/GpH/yBizdvm/ZQJEmSJEmSJmrFtAcgLWV/uORyPnL6uQBcsnkbjzjuiOkOSJIkSZIkaYLMfJRadO4ll19x/6wLN7Erc4qjkSRJkiRJmiyDj1KLLtuy/Yr7W3fs4pLN2/tsLUmSJEmStLQYfJRatGHLjo7lP156eY8tJUmSJEmSlh6Dj1KLNlzemen4x0u3TGkkkiRJkiRJk2fwUWrRZbMyHw0+SpIkSZKkPYfBR6lFG7Z0Zj6e57RrSZIkSZK0BzH4KLVk566cVfPx4s3buXzbzimNSJIkSZIkabIMPkot+fOmrWSX9edd5tRrSZIkSZK0ZzD4KLXkgsu2dl1vx2tJkiRJkrSnMPgoteT8HhmO59l0RpIkSZIk7SEMPkotOb9n5qPBR0mSJEmStGcw+Ci1pJ75eNDeqzvW79zVrRqkJEmSJEnS0mLwUWrJBRt2Zz4eefDeLItyf8eu5MKN3bMiJUmSJEmSlhKDj1JLLqhlPh6wdiVXqmU/OvVakiRJkiTtCQw+Si05f8PuAOP6vVZy2L5rrlg+z47XkiRJkiRpD2DwUWpJveHMPmtWcti+e12xbOajJEmSJEnaExh8lFqwY+cu/ryxHnxc0ZH5aPBRkiRJkiTtCVZMewDSUvTnTduoN7Rev2Ylq1bsjvVv3LqDDVu2s37NyimMTpIkSZIkaTLMfJRacH6t2czqFctYtWIZ69esZO/Vu+P955n9KEmSJEmSljiDj1ILmvUeZzj1WpIkSZIk7UkMPkotqGc+rt9rd7ZjZ/DRjteSJEmSJGlpM/goteCCDd0zHw+147UkSZIkSdqDGHyUWnBBLfNxnzXdMx8v3LiV7Tt3TXRckiRJkiRJk2TwUWpBx7TrWubjlfZezYplAcCuhAtqtSElSZIkSZKWGoOPUgvqDWfW1zIfly8LDtnHuo+SJEmSJGnPYPBRasEFG+rTrld2PHZovenMZdZ9lCRJkiRJS5fBR2nMtu/cxZ83bbtieZ+9OoOPHR2vLzH4KEmSJEmSli6Dj9KYXbhxK5m7l+vTrgEOq3W8Pu+yy8n6xpIkSZIkSUuIwUdpzOr1HvdauZyVyzufZofWaj5u2b6Ly7bsmNjYJEmSJEmSJsngozRmnZ2uV8x6fK9Vy1lVC0hu3mbwUZIkSZIkLU0GH6Uxu+Cy3s1mZuy1avkV9y/ftrP1MUmSJEmSJE2DwUdpzC7YsHvadbfMRyjTsWdsNvgoSZIkSZKWKIOP0pjVp103O13PqGc+btlu8FGSJEmSJC1NBh+lMas3nDHzUZIkSZIk7ckMPkpj1tlwZoCaj2Y+SpIkSZKkJcrgozRm9ZqP+/TIfFy70oYzkiRJkiRp6TP4KI3Rth27uGjTtiuWB+p2beajJEmSJElaogw+SmP0p41bO5Z71nw0+ChJkiRJkvYABh+lMarXe9x/7UpWLO/+FNvLadeSJEmSJGkPYPBRGqMLasHHg9ev6bmdmY+SJEmSJGlPYPBRGqPzL9s97frgfVb33K6e+bh5245WxyRJkiRJkjQtBh+lMbpgw+7Mx0P26Z35uHbV7lqQW7fvYldmq+OSJEmSJEmahokEHyNi30mcR5q2eubjIQNmPiawxanXkiRJkiRpCZpU5uMfIuLdEXH7CZ1Pmop6w5l+mY+rVy4jass2nZEkSZIkSUvRpIKPewEPA74SEWdGxNMj4qAJnVuamAvqNR/7NJxZFsGalTadkSRJkiRJS9ukaz4GcG3gn4HfR8SHIuIeEx6D1JrzazUf+zWcgUbHazMfJUmSJEnSEjSp4OMrgT821q0E7g98JiLOjogXRMTVJjQeaey27tjJJZu3X7Hcb9o1dNZ9NPNRkiRJkiQtRRMJPmbmc4CrAycCnwRmIi1R3a4OvBA4KyL+JyJOjIjl3Y4lLVT1KdcAB+09eObjZjMfJUmSJEnSEjSxadeZuSszP5mZJ1KCjc8FftXYbDlwd+AjlGnZr4iIa09qjNJ8XFCbcn3gulWsWtH/6VXPfLTbtSR
JkiRJWoomXfMRgMw8LzNfkZnXAe4IvB+YSRubyYY8BHgG8POI+EpEPDQi+qeSSVN0fr3ZzBxTrsHMR0mSJEmStPRNJfhYl5mnZubDgcOAJwNnNDYJ4PbAe4A/RMS/R8Qxkx2lNLfzL6s1m1k/d5x8rTUfJUmSJEnSEjf14OOMzLw0M9+QmccCNwPeClxWPTyTDbk/8CTg+xHxrYh4TESsm86IpU4XbNid+XjIHJ2uwW7XkiRJkiRp6Vswwce6zDw9Mx9PyYY8CTgfyOo2E4i8GfA24NyIeJ2dsjVtf6oFHw9eP8C0azMfJUmSJEnSErcgg48AEXEo8FTgecDBtYdyZpPq6z7AE4FfRMQ/RcTKiQ1Sqtm4ZccV9/fZa8Wc25v5KEmSJEmSlrq5IyQTFBHLgHsDjwXuSel+fcXD1dc/Au8DbgjctbZ+NfBs4OYRcc/M3DWRQUuVTdt2Bx/Xrhog+GjmoyRJkiRJWuIWRPAxIq4FPBp4JHDozOraJruAL1DqQH4qM3dW+10NeBzweGC/ap+7AE8A3jCJsUsz6h2r161e3mfLorPb9Y4+W0qSJEmSJC1OU5t2HRGrI+LhEfEV4EzgmZQajzM1HaFkOb4MODIz75mZH58JPAJk5u8y87nAtYEv1Q7/iIl8E1LNpq27A4jrhsx83L4z2bHLZF1JkiRJkrS0TDzzMSJuQplW/RBg35nVtU12AV+kkeXYT2b+OSJOAs6hTNU+apxjlgZRn3a9bvXcT63m1OzLt+1k/ZoFW4ZVkiRJkiRpaBMJPkbEvsDDKEHHG82sZnf3aihZju8E3p6Z5wx7jsz8Q0ScA1wTWDfvQUtD2rx1d5x87aq5p12vXB4sj2Bnlh5KJfhovyRJkiRJkrR0TCrz8Y+UhjCwO+hI9fULwNuATw6S5TiHDfPcXxrZxtq0670HyHyMCPZatfyK/Ww6I0mSJEmSlppJBR/X0JnleB7zyHLs44+UxjPSRO3YuYutO3bXbFw7QPARSt3HK4KP2ww+SpIkSZKkpWWSNR/HneU4+wSZ9xr3MaVBbG5kLa4bYNo1dHa8NvNRkiRJkiQtNZMKPr6c8Wc5SgtGvdM1zG4m00u94/VmMx8lSZIkSdISM5HgY2Y+bxLnkaZlU63ZzKrly1i1YrCu1WY+SpIkSZKkpWxS3a5vX929PDO/M4/jHAvsDZCZXx3H2KRx2Lxtd+bj2tWDTbmGRvDRzEdJkiRJkrTETGra9SmUmo+/Aq47j+O8AzimOtYk61VKfdU7Xa8bcMo1dE67NvNRkiRJkiQtNZMM4AW7u13P9zjSgrK5Nu163RCZj2vNfJQkSZIkSUvYYIXpxiMneC5pojbVp12b+ShJkiRJkgRMNvg4DjORmh19t5ImbNOImY/WfJQkSZIkSUvZYgs+HlZ93TjVUUgN9YYzo9Z83GzmoyRJkiRJWmIWTfAxIu4IHEiZvn3OlIcjdejMfBwi+FjLfNyybSeZVieQJEmSJElLx1gbzkTEMcCN+2yyPiL+aohDLgP2BW4APLi2/rThRye1p7Pm4xDTrmuZjzsz2bZzF6tXDL6/JEmSJEnSQjbubtf3A17Q47EADgbeOeKxZ7pcJ/CfIx5DasWmrbuDj3uPmPkIpe6jwUdJkiRJkrRUtDHtOubeZOjj1QOPL8zM7475HNK8bK41ixmm2/WKZctYtXz309CO15IkSZIkaSkZd+bjjF4ByGEDkzsozWXOpky1fmdmfmce45JasbGW+ThMt2so2Y/bLt8F2PFakiRJkiQtLWMNPmbmi4EXN9dHxC5K1uKvM/M64zyntBB0dLseYto1lLqPl16+HTDzUZIkSZIkLS2T7HY97unY0oJR73Y9TMMZ6Kz7aOajJEmSJElaStqadt00kw150YTOJ01UveHMuiFqPkJnx2szHyVJkiRJ0lIykeBjNR1bWrLqDWeGnnZdy3zcbOajJEmSJElaQiY57VpasjZtG73hzFozHyVJkiRJ0hJl8FEag/q067XDTru25qMkSZIkSVqiDD5K87Rtxy6278wrlveex7RrMx8lSZIkSdJSMraajxFRj5pkZq7o8dg4dBxfmqbNtSnXAGuHnHbd0XDGzEdJkiRJkrSEjDOAF0BWX4d5TFrUNm5tBB9XDhl8NPNRkiRJkiQtUeOedt0vuGjgUUtSvUP1mpXLWLF8uKdVPfOxmUUpSZIkSZK0mI0z8/FRIz4mLWr1ZjPrhmw2A53Bx63bd7Erk2VhrF6SJEmSJC1+Yws+Zua7RnlMWuw2bd2d+ThsvUfo7I6dlABkfSq2JEmSJEnSYmW3a2meNm2bX+bj6pXLOmoSOPVakiRJkiQtFQYfpXmqBwvXrR4++LgsgjUrbTojSZIkSZKWHoOP0jxtrE+7HnG6dEfH620GHyVJkiRJ0tIwzoYzYxURa4FrAyuBczLzT1MektTV5lrDmb1HyHyEzqYzZj5KkiRJkqSlYiKZjxGxd0Rcs7pdZY5trxwRHwQuBk4HvgWcFxFfi4hbTGK80jA2batnPo4YfKxlPm4281GSJEmSJC0Rk5p2/a/AL6vbP/baKCIOBU4DHkjJeIza7TbA1yPixLYHKw1j09Z6zccRp13XMh+3mPkoSZIkSZKWiEkFH+8LVzT0fWOf7d4EXLW6n43HkjJN/L0RcfXxDk8a3XwbzoCZj5IkSZIkaWlqPfgYEUcAh1KChz/LzF/22O5o4ER2Bx0vBP4euBfwHGBj9dhewIvaHLM0jE21hjPrRm04Y81HSZIkSZK0BE2i4cz1a/e/2We7R1RfA7gcOC4zz6rWfS4iTgO+XC0/KCKekJlbxjtUaXj1adej1nxca7drSZIkSZK0BE1i2nV9ivTP+mx3z+prAh+sBR7LysxTgFOqxbXAsWManzQvm7bZ7VqSJEmSJKmbSQQf96ndv7jbBhFxJeAGtVX/3eNYp9TuHzW/YUnjUa/RuHbUhjNmPkqSJEmSpCVoEtOuV9buN5vIzLgNuxvSbAdO7bHd72v395/nuBa0iFgG3Bo4EjgMuBQ4F/haZnYN4mo6Nta7XY847drMR0mSJEmStBRNIvi4oXb/wB7b3KH6msD3MvPyHtvVg5er5jswuCLIdxRws9rtRpTGNjPuWE37bl1ErACeCTwBuHKXTbZFxKeAp2fm2ZMYk/rbXG84M4Zu12Y+SpIkSZKkpWISwcdza/d71Wk8oXb/632OdUDt/saRR1SJiI8AdwfWzfdY4xARhwCfpgRAe1kFPAC4a0T8VWZ+YiKDU0/1mo9rx9DtetvOXezYtYsVyyZRFUGSJEmSJKk9kwg+nl59DeCEiDgwM/8882BE3IUytXjG//Y51nVq9/84hrHdlIUTeNwL+ASdgcdzgfcCv6Zkjd4TuH312D7AByLiTpnZr4u4WpSZHd2uR818bHbJvnzbTtavMfgoSZIkSZIWt9ajG5n5G+D7lCnT64BPRcQNImJ1RNwReCe7p1NfSP/g4y1r93855qFuBb4DvIUS8Ju0l9D5/X0YODIzn5WZb8/MV2bmHYCHUepiAqwBPhgRayY8VlW27t
jFrloxgHUjNpxZuTxYHnHFslOvJUmSJEnSUjCp1KqXs7uhzC2BHwCbgS8BV6keS+C1mdk16hIRhwPHVItbgB+PYVzvBv6GkgG5PjNvkZmPp38AdOwi4qrAk2qrfgg8NDO3NrfNzPcDL6ituhrwxHZHqF7qWY8wesOZiOis+2jTGUmSJEmStARMJPiYmR+hZBTOBCCjdpvJG/s28Oo+h3nYzOGAb2fmjj7bDjquF1RZhadn5va592jN4ylZjDOeMcd4/pXOWppPbWNQmtumrZ1BwnrtxmHZ8VqSJEmSJC01Eysql5lPoATZzmk8tAV4M3CXzNzWbd+IWMXuzMAA/qetcU7J/Wr3zwG+0G/jKvD6ztqqq0ZEvyY1akm92cy6VctZtiz6bN2fHa8lSZIkSdJSM4mGM1fIzLcCb42IawCHUqZe/6xX0LFmf+DZteXPtTTEiat+FkfVVn0pM7PX9jVfBJ5XW74P8N1xjk1z21zvdD1is5kZZj5KkiRJkqSlZqLBxxlVE5rfDLH9+cC72hvRVN2osXzagPt9G9jB7t/hMX22VUs21qZdr1s1+pRr6Mx83GzmoyRJkiRJWgImNu1aPR3VWP7VIDtl5hbgD7VV1x/biDSwzbWGM+vmm/lowxlJkiRJkrTEGHycvms2ln87xL71bZvH0QRs2lbPfBzjtGszHyVJkiRJ0hJg8HH69mksXzTEvhfX7q+MiNVjGI+GsGlrvebj/KZdr7XhjCRJkiRJWmKmUvMxItZSah0eBewHrKN0sR5YZr5k/CObir0by1uG2PfyLsfaOujOEfGTHg8dOcQY9mgd3a5tOCNJkiRJktRhosHHiLgBpUPzfYH5ZuktleDjmsbyXJ2/65qBxr3mORYNaXNLDWfMfJQkSZIkSUvBxIKPEfF44LXVOWeyHJMhMx5r+y0VzUzHVV3W9dIM4DYzIfvKzKO7ra8yIm1gM4CN9WnXY6z5uNnMR0mSJEmStARMJPgYEScCb6wW64HDpNQ43DiJcSxQze99DYMHH5uZjnvyz3EqNtemXe89xm7XW7btJDOJGCU2L0mSJEmStDC0HnyMEj35t2pxJtPxv4C3At/OzGFqHC5FlzWW9wcuGXDf/Wr3t2fmwPUeNR71btfzbThTz3zcmcn2ncmqFQYfJUmSJEnS4jWJzMebA0ewO+PxMZn5zgmcd7H4TWP56l3W9XJ47f5Z4xmOhlHvdr1uvtOuGzUjN2/bwaoVq+Z1TEmSJEmSpGlaNoFz3Lh2/38NPM7y08bytQbZKSLWAFfucxxNQEfDmXlOu16xbBmrlu9+StrxWpIkSZIkLXaTCD4eULv/2Qmcb7H5QWP5uAH3uwWdmas/Gs9wNIxN2+qZj/Obdg12vJYkSZIkSUvLJIKPf67dv3gC51tUMvM3wM9rq+4Sg3UZuWtj+dPjG5UGVZ92vXaemY/QWffRzEdJkiRJkrTYTSL4eHbt/kETON9i9LHa/cOBu/XbOCJWAI+qrToX+G4L49Ic6g1n9p5nwxkw81GSJEmSJC0tkwg+nsLu7Mc7TuB8UxcRR0RE1m6nzLHLm4F6p+pXRcTKPts/HbhKbfm1mZm9NlZ7NtczH+fZcAbMfJQkSZIkSUtL68HHzNwOvBEI4G4RceO2z7nYZObvKD+jGccA74uI1c1tI+IhwItrq84F3tDuCNXNrl3Zkfk4327X0Jn5uNnMR0mSJEmStMjNP1oymJdSahQeB3wkIu6Ymb+d0Ll7ioj7A6/q8tD6xvL7IuLyLts9IzM/OqbhPB+4PXCzavlBwK0j4j3AWcD+wL2AO9T22Qr8ZWZuGdMYNIRmZuK6MUy7XmvmoyRJkiRJWkImEnzMzJ0RcU/gv4B7Aj+IiFcA78rM8ycxhh72AY4cYLsr99l/LDJzc0ScAHwGOLZafRXgWT122QA8MjO/Pq4xaDj1TtcA68bRcMaaj5IkSZIkaQmZSPAxIr5c3V0G7AL2BV4BvCIizgHOA4bJ3svMvPN4Rzl9mXleRNyKEnB8AnBol822UQKU/1B1ytaUbNq6Ozi4LGD1ivlXMegIPpr5KEmSJEmSFrlJTbs+Hqg3RElKDUiAIygdngcVjWONLDNPBk4ex7Eaxz2b3d/fsPtuB14aES8Hbg1cCziEkun4e+BrmXnRmIaqedhUazazbvUKIkb6lXfoaDhj5qMkSZIkSVrkJhV8hP7BuPlHbZaYzNwJfK26aQHaPOZmM2C3a0mSJEmStLRMKvj4rgmdR5qYeubj2jE0mwFrPkqSJEmSpKVlUg1nHjWJ80iTVG8400bm45btO9mVybIxTOeWJEmSJEmahvl3yJD2UJtrDWfWjSnzcW0tiJnA1u27xnJcSZIkSZKkaTD4KI1o49bxZz6uXrmsowDq5lp2pSRJkiRJ0mJj8FEaUT0wuHb1eIKPyyJYY9MZSZIkSZK0RBh8lEa0qdYQZu8xTbuGRtMZg4+SJEmSJGkRm1S361ki4l7AXYFbAlcF9gfWAr/KzOs2tl0J3KRa3JmZ35vkWKVuNte7XY9p2jV0Np2x47UkSZIkSVrMJh58jIiHAS8FDq+v7nEfgMzcHhHvBq5dHeMmmfnDVgcqzWFjveHMKjMfJUmSJEmSmiY27ToiVkTEfwHvpgQeo3aD0ty3nzfWtn14K4OUhlCv+bhuTDUfwcxHSZIkSZK0dEyy5uP7gAezO+C4Hfgs8GLgCdW6fgHID9Uev2d7w5QGU6/5OK6GM9DIfDT4KEmSJEmSFrGJTLuOiL8EHkQJHgbwEeDJmfnH2jZv6neMzDwvIr4PHAtcPyIOzMw/tzhsqa9NtZqPY512Xct83Oy0a0mSJEmStIhNKvPxxbX7b8nMB9UDj0M4vXb/BvMckzQvHcHHMWY+rjXzUZIkSZIkLRGtBx8j4vqURjEJ/A74+3kc7he1+0fOZ1zSfG3eVm8401LNRzMfJUmSJEnSIjaJzMdja/f/OzO3zuNYl9Tu7z+P40jzVs98XLu6pW7XZj5KkiRJkqRFbBLBx0Nq98+c57HqkZhV8zyWNC+bat2u926r27WZj5IkSZIkaRGbRPCx3sF6vhGaA2r3L57nsaSR7dyVbNm+64rlteNsOGPmoyRJkiRJWiImEXy8oHZ/vnUab1y7f/48jyWNrJ71CO3VfNy2cxc7du3qs7UkSZIkSdLCNYng409r9+8z6kEiYjVw99qq00YekTRPm7d2ZiSOt9t157HMfpQkSZIkSYtV68HHzDydkqUYwHUj4pEjHuqJwJUo07h/mpl/HNMQpaHVMx9XLg9WrRjfU2nl8mB5xBXL1n2UJEmSJEmL1SQyHwHeUX0N4E0Rcddhdq62f3lt1evGNTBpFB2drsc45RogIlhj3UdJkiRJkrQETCr4+EpK7ccE9gL+JyLeFBHX6bdTRBwQEa8APk3pbp3AL4D/bHm8Ul+batOux9npesZaO15LkiRJkqQlYPxRky4yc2NEnAh8iRJ8XA78LfC3EXEW8JPa5gdExJuBo4FbVdvOzEG9DDgxM43GaKo2b6tnPo6v0/UMO15LkiRJkqSlYFKZj
2TmacAJdHa/DkoH7BMoWY0A+wN/A9yGzuDoH4F7ZuaZ7Y9W6m9jfdp1C5mPe5n5KEmSJEmSloCJBR8BMvMrwDHAycD22kPR2DRq63YC7wGOrQKY0tRt3lafdt1u5uNmMx8lSZIkSdIiNZFp13WZ+Sfg0RHxbOCBwO2AGwEHAvsBm4ELgZ8DXwE+nJnnTHqcUj9tNpyBxrRrMx8lSZIkSdIiNfHg44zMPB94Y3WTFpV6w5l1bdR8rE273mLmoyRJkiRJWqQmOu1aWirqDWfWtdHt2mnXkiRJkiRpCTD4KI1gU8vBRxvOSJIkSZKkpWCi064jIoBjq9uVgAOAfYBLgYsotR6/m5lnTHJc0rDq067XtjHtul7z0cxHSZIkSZK0SE0k+BgRtwCeAdyZEmyca/tLgC8C/5KZ32t3dNLw6g1n9jbzUZIkSZIkqatWp11HxCER8Wngm8D9gH2BqG5dd6lu+wMPAr4dER+PiIPaHKc0rHodxla6Xa/szHzMzLGfQ5IkSZIkqW2tBR8j4ijgNOCe7A421iMo0eVGY7sATgC+GRHXaWus0rA2bq3XfGx32vXOTLbvNPgoSZIkSZIWn1amXUfE1YCvAgdSAolJCSRuomRBfgM4G7gY2AisB/YDrgkcB9wKWMfuIOQ1gVMj4maZeW4bY5aG0dHtuo3Mx0Ydyc3bdrBqxaqxn0eSJEmSJKlNbdV8/E92Bx4D+BPwz8DbM3PDXDtHxD7A3wL/SGlMk8AhwH9QMimlqepoONNC5uOKZctYtXwZ23buAkrdx/3GfhZJkiRJkqR2jX3adUTcndJYZiZr8TvAsZn5b4MEHgEy87LM/BdKV+zvsntK9t0i4s7jHrM0rE0tZz5Co+O1TWckSZIkSdIi1EbNx6dUXwP4HXD3UadKZ+bvgXtUx5kJZj51vgOU5mtzLfNxXQvdrmF20xlJkiRJkqTFZqzBx4g4ALhLtZjAYzPzkvkcMzMvAh7L7qY0d4uI/eZzTGk+tu3YdcV0aGin4Qw0Mh8NPkqSJEmSpEVo3JmPd6TUkUzgR5n5pXEcNDO/CPyoWlwB3Gkcx5VGUW82A7C2rWnXK512LUmSJEmSFrdxBx9vVbv/7jEfu368W/XcSmrZpkYW4rpVZj5KkiRJkiR1M+7g4/Vq97815mOfVrt//TEfWxrY5q27Mx9Xr1jGiuVtlE6FtbXMx81mPkqSJEmSpEVo3FGTI2r3vzfmY59eu3/4mI8tDWxjLfjYVrMZMPNRkiRJkiQtfuMOPh5cfb08M7eM88CZeTmwmdJ05uA5Npdas3lbvdN1O1OuAdZY81GSJEmSJC1y4w4+7k1pNnPJmI87Y+a461s6vjSnTfXMx5aazQCsNfNRkiRJkiQtcuMOPq6uvm4e83FnXF59XdXS8aU5bap1u17bUrMZsNu1JEmSJEla/MYdfGyn88ZsMaHzSLNs2lqfdm3NR0mSJEmSpF4mFSyUlozN2yYz7bqe+bhl+052ZbZ2LkmSJEmSpDYYfJSGtLGW+bi2xYYza2uBzQS2bt/V2rkkSZIkSZLa0Fba1vqI+Ks2jtvCMaWhbK41nNm7xWnXq1cuIyiBR7DuoyRJkiRJWnzaipwcDLyzpWNLU7WpVn9xbYvTrpdFsGbl8iuCjvXp3pIkSZIkSYtBe5GTdprCWPROU7dpa73mY3vTrqE0nZkJPpr5KEmSJEmSFps2go9tdqK2y7WmrqPhTIvTrqGz6YwdryVJkiRJ0mIz7sjJo8Z8PGnB2VRrOLOuxYYzUDIfZ5j5KEmSJEmSFpuxBh8z813jPJ60EG2qZT62WfMRzHyUJEmSJEmL27JpD0BabDZNqNs1NDIfDT5KkiRJkqRFxuCjNKTNHd2uW552vdJp15IkSZIkafEy+CgNaePWyTWcqQc3N5v5KEmSJEmSFhmDj9IQMrMjCDjJbtf1LtuSJEmSJEmLgcFHaQhbd+xi5668Ynldy9Ou16/ZHdzcsMXgoyRJkiRJWlwMPkpDqDebAVjbcubj+jUrr7i/YcsOMrPP1pIkSZIkSQuLwUdpCM26i2tXtpv5uM9eu4OP23bu6qg3KUmSJEmStNAZfJSGsKlWd3HtquUsWxatnm/tquXUT3HBhq2tnk+SJEmSJGmcDD5KQ6hPu167qt0p1wDLIjqmXp9/2ZbWzylJkiRJkjQuBh+lIWzaunva9d6r251yPaPedOaCy8x8lCRJkiRJi4fBR2kIm7dNNvMROpvOXLDBzEdJkiRJkrR4GHyUhrCxlvm4bkKZj/vUMh/PN/NRkiRJkiQtIgYfpSFMJ/OxNu3ahjOSJEmSJGkRMfgoDaGz5uNkgo/72HBGkiRJkiQtUgYfpSF0drueRsMZg4+SJEmSJGnxMPgoDWFTbdr1ugllPnY2nNlKZk7kvJIkSZIkSfNl8FEawuZpNJzZa3fwcfO2nWysZV9KkiRJkiQtZAYfpSFsnELDmbWrlrMsdi/bdEaSJEmSJC0WBh+lIWyuZR2um1DNx2URHVOvbTojSZIkSZIWC4OP0hA2batPu55M5iM0m86Y+ShJkiRJkhYHg4/SEOrdricbfKw3nTHzUZIkSZIkLQ4GH6UhbK5lPq6d0LRrgH1qmY/nm/koSZIkSZIWCYOP0hDqmY97T2natTUfJUmSJEnSYmHwURpCPfg4qW7XAPt0TLs281GSJEmSJC0OBh+lAe3alWzeXm84M7lp1x01H818lCRJkiRJi4TBR2lAW3bsJHP38tS6XW/YStYHIkmSJEmStEAZfJQGtLE25Rpg3SSnXe+1O/Nx87ads8YiSZIkSZK0EBl8lAa0eevuKdfLAtasnNzTZ+2q5SyL3ct2vJYkSZIkSYuBwUdpQJu27c42XLdqBRHRZ+vxWhbRWfdxg3UfJUmSJEnSwmfwURrQplrm49oJNpuZ0VH30cxHSZIkSZK0CBh8lAbUzHycNDMfJUmSJEnSYmPwURpQvebjJDtdz9inlvlozUdJkiRJkrQYGHyUBrSp1mF67arpTrs+/zIzHyVJkiRJ0sJn8FEaUMe066lkPtanXZv5KEmSJEmSFj6Dj9KANm+b7rTrjpqPZj5KkiRJkqRFwOCjNKCNW+sNZ6bc7XrDVjJz4mOQJEmSJEkahsFHaUCbO2o+TmHa9V67Mx83b9vZEQyVJEmSJElaiAw+SgPaVJt2vffqyWc+rl21nBXL4oplO15LkiRJkqSFzuCjNKCObtdTqPm4LIKD1q++YvmCDdZ9lCRJkiRJC5vBR2lA9czHadR8BDh4nzVX3L/AzEdJkiRJkrTAGXyUBlSv+TiNbtcAB9cyH8+347UkSZIkSVrgDD5KA9o45YYzAIfsU592beajJEmSJEla2Aw+SgPaXJ92PYWGMwAHr9897drMR0mSJEmStNAZfJQGtHnb9Kddm/koSZIkSZIWE4OP0oDq067XTWnadWfDGTMfJUmSJEnSwmbwURrAzl3Jlu27rlheO61u1x0NZ7aSmVMZhyRJkiRJ0iAMPkoDqE+5Bth7atOud2c+Xr59Z0c2piRJkiRJ0kJj8FEawKatOzuW
106p4cwBa1exYllcsXz+ZdZ9lCRJkiRJC5fBR2kAm2qZjyuWBauWT+eps2xZcND6etMZ6z5KkiRJkqSFy+CjNIDNtczHdatXEBF9tm5XZ9MZMx8lSZIkSdLCZfBRGkBnp+vpTLme0dl0xsxHSZIkSZK0cBl8lAZQbzizdkrNZmYcsk992rWZj5IkSZIkaeEy+CgNYNO2zmnX03TI+t3Trs18lCRJkiRJC5nBR2kAmxbStGszHyVJkiRJ0iJh8FEaQD34uHbVdDMfOxvOmPkoSZIkSZIWLoOP0gA216Zd7716ITWc2UpmTnE0kiRJkiRJvU03hWuBioijgWOAKwM7gXOB72bmb6Y6ME1NR+bjtGs+1jIfL9++k0s2b2f/daumOCJJkiRJkqTuDD7WRMQDgedTAo/dHv8G8NzMPKWFcx8PfGXE3W+emd8d32jUtGnbwqn5eOC6Vey710ouvXw7AGeev4FbXfPAqY5JkiRJkiSpG6ddAxGxPCLeCXyIHoHHyq2B/42Il05mZFooNm9dON2uI4LrHrr+iuUzz9swxdFIkiRJkiT1ZuZj8RrgpNryZuB9wBnAKuCWwAOAlZSA7fMi4qLMfE2LYzoH2DHnVoVdR1q2saPb9fSfNtc7dD3f/s1FAPzc4KMkSZIkSVqgph9FmbKIuDfwd7VVPwXukZm/a2x3I+B/KHUgAf41Ir6UmT9qaWjHZ+bZLR1bQ6o3nFk75YYzANc5ZHfm4y/ON/goSZIkSZIWpj162nVELANeXlu1GTihGXgEyMwfAA8CdlWrmvtqCavXfNx7ytOuoWQ+zvjFeRvseC1JkiRJkhakPTr4CNyZzhqPr8vMs3ptnJnfoNSFnHGfiLhWW4PTwtHR7XoBTLu+Ti34uGHrDs695PIpjkaSJEmSJKm7PT34eL/G8n8MsM/bG8snjmcoWsg21RvOTLnbNcA+a1Zylf32umLZpjOSJEmSJGkh2tODj/eu3f91Zv56gH2+RmeDl/uMd0haiDbXpl1Pu9v1jHrHa5vOSJIkSZKkhWiPDT5GxH7A1WurThtkv8zcBnyvtuqYXttq6ejIfFwADWegM/ho5qMkSf+/vfuOj+sq8z/+fWbUJcuSe4m705wGCYnTkw0JLQtLC7AhhLIsgVAWtoT9hYXfssvSFn6EXTphwyYQ6kJCykIIgfQKpOEUd8d27LhILurl+f1xp9y51sgz1ozmSvN5v17z8j1nzr33kXQsXT06BQAAAHFUtclHSUdHymuKODc8QrLdzOaUIJ6oT5vZo2bWYWb9ZvaCmT1uZt80s9ebWTwyYFWgf3BY/UPDmXIc1nyUIpvOsOM1AAAAAACIoWpOPi6NlDcVcW60bfRapfCXkk6Q1CapVtJMScdJeo+k/5H0rJn9RRnui4ie/qGcchynXa/dsV8DoQQpAAAAAABAHFRz8rE1Ut5dxLkdkfKUEVuNXYekjZJ2SopmlpZKusHM/q1M90bK/tB6j5LUFIMNZyRp6YwW1SRMkjQw5Fq3o6vCEQEAAAAAAOSq5uRjS6TcO2KrkfUc5FqHapek/5T0CknT3X2auy9295mSpkl6vaR7I+dcaWZ/cyg3M7M/jfSStGwsH8Rk092XTT7W1SRUm4zHf5u6moSWzmzOlJ/etreC0QAAAAAAABwoHlmUymiIlPuLOLcvUm4cYyxSsInNYe7+IXf/lbvnjMR09z3u/nNJZ0n6eOTcz5nZghLEgBF0haZdt8RkynXakXOyA3jZdAYAAAAAAMRNNScfoyMd64o4tz5Sjo6ELJq773P3g46+9MCnJH0jEs8Vh3DPY0Z6KXdDnarXFRr5GJcp12lHseM1AAAAAACIsWpOPu6PlKMjIUcTHekYvdZ4+CflJj1fXYEYqkI4+dgck52u046cnU0+Pk3yEQAAAAAAxEw1Jx+jC+S1F3FuW6Q87lkfd98l6c5Q1SIzmzvecVSD7tC06+b6eI18DO94vaWzR/t6ByoYDQAAAAAAQK5qTj6uj5QXFnHuokh53RhjOVTPRMqzKhLFJNcV2u26OWZrPh7W3pizDuWz2ysxCBcAAAAAAGBk1Zx8XBUpLy/i3PBu0B3uvq0E8RyK6FqTTRWJYpKL85qPZqYjZmc3W2fdRwAAAAAAECdVm3x0905Jm0JVpxVynpnVSTopVPVECcMq1uxIeWdFopjkuvrC067jNfJRyp16/cy26GoCAAAAAAAAlVO1yceUW0PHy8xsaQHnnKXczWluLm1IRTkrdDwgaUulApnMuvvju+GMxKYzAAAAAAAgvqo9+fjzSPmvCzgn2uaG0oRSHDN7pXKnit/r7t2ViGWy2x8a+dgUsw1nJOnIOa2Z42e275O7VzAaAAAAAACArGpPPt4u6clQ+YNmtiRfYzM7TdJFoapb3H11nraLzcxDr9+Nct3GYoJO7Wr9zUj1d4u5BgoX95GPR4WmXXd2D+iFfX0VjAYAAAAAACCrqpOP7j4s6cpQVbOkm8xsQbStmR0v6SfKfs6GJX2sRKG82czuNLPXpNaUzMvMzpf0oKRwjI9Juq5EsSAi7ms+tjfXadaU+kyZTWcAAAAAAEBcxC+TMs7c/SYz+5qky1NVx0h6ysy+L+lRSbWSTpX0xtRx2kfd/bEShnJ26tVpZvdKelzS85L2KdjFeomkCySdEDlvm6TXphKpKIPwbtfNMdvtOu3IOVMyIx6f2bZPZx8xs8IRAQAAAAAAkHxM+5CkKZLelio3S3pPnrYu6bPu/oUyxdIm6cLU62AekHSJu28oUyxQ7rTrphiOfJSCqdd3rw42O2fTGQAAAAAAEBdVPe06zd2H3P1SSW9W7hqQUQ9IOt/drxylzaF4RNI1kp5SkNwcjUu6T9Ilks5097UljgURXf3ZadctMdxwRpKOCO14/cz2vRWMBAAAAAAAICuew7gqxN1/LOnHZnaspOMlzZM0JGmrpIfdfV0R19ogyQps+6Skd0mSmbVJerGkhZJmSGqU1CepU9IGSQ+5+55C48DYhaddN8VwwxlJOiq04/Xq7fs1NOxKJgrqfgAAAAAAAGUTz0xKhaWSgaONgCznvTsl/bYS98bIctd8jOd/mcNntyhh0rBLfYPD2rCrS8tmtlQ6LAAAAAAAUOWYdg2Mwt3V3R/e7Tqe064bapNaPL05U35yC4NjAQAAAABA5ZF8BEbRNzisweHsMpzNMd1wRpJetLAtc/zbp1+oXCAAAAAAAAApJB+BUYRHPUpSU108Rz5K0gVHz84c3/H0CxoYGq5gNAAAAAAAACQfgVGF13uU4rvhjCSdfcRM1dUE/6X39g7q4fW7KxwRAAAAAACodiQfgVF09WeTj421yVjvIN1cX6Mzlk3PlG9btb2C0QAAAAAAAJB8BEbV1Rf/zWbCLlgxJ3P861Xb5e6jtAYAAAAAACgvko/AKPb2DmSOW2K82Uza+UfPyhxv6ezRU8/vq2A0AAAAAACg2pF8BEbR2d2fOW5rqqtgJIWZ1dqgFy1oy5R/zdRrAAAAAABQQSQfgVF0dGVHPrY31VYwksJdsCK76/Wvn9pWwUgAAAA
AAEC1I/kIjCI88rG9Of4jHyXpZaHk45Nb9mprZ08FowEAAAAAANWM5CMwio7u8MjHiZF8XD6rRUtmNGfKtz/F1GsAAAAAAFAZJB+BUXSERz5OkGnXZpY79Zp1HwEAAAAAQIWQfARG0Rka+TgRNpxJCycfH1i3K2fXbgAAAAAAgPFC8hEYRe7Ix4mTfDxxYbump9aoHBhy/e6ZHRWOCAAAAAAAVCOSj8AoOrom3rRrSUomTOcdNStTZuo1AAAAAACohJpKBwDEWUfMpl1f/+CmgtvW1yQzx7f9aZuuvX+DahK5f2+4eOXCksUGAAAAAAAQxchHII/egSH1DAxlyu3NE2fkoxTsel2bNElS3+Cw1u/sqnBEAAAAAACg2pB8BPIIbzYjTaw1HyWpriah5TNbMuUnNu+pYDQAAAAAAKAakXwE8ghvNtNQm1BDbXKU1vF0zPypmeM/PtepvT3seg0AAAAAAMYPyUcgj4m603XY8fOnqrUhWNp1aNh192p2vQYAAAAAAOOH5COQR2fMNps5FDXJhM46fGam/NCG3drfN1jBiAAAAAAAQDUh+QjkER75OG2CbTYTdvLiaWquD0Y/Dgy57l2zs8IRAQAAAACAakHyEchjMox8lIKNZ85aPiNTvn/dLnX3M/oRAAAAAACUH8lHII+OrvCajxN35KMkrVwyTY2pDXP6B4d139pdFY4IAAAAAABUA5KPQB4doZGPE3XDmbT62qTOWD49U75v7U71DgxVMCIAAAAAAFANSD4CeXSG1nycyNOu005bOkP1NcF/+d6BYT2wjtGPAAAAAACgvEg+Anns7p48064lqbEuqdOWZUc/3rNmJ2s/AgAAAACAsiL5COTROYmmXaedsWyGapMmSeruH9L1D26qcEQAAAAAAGAyI/kI5NGRM+164o98lKTm+hqtXJId/fj1363N2VgHAAAAAACglEg+AiMYGnbt6Zl8Ix8l6azDs6Mfd3X1619vXlXhiAAAAAAAwGRF8hEYwd6eAblny5Mp+TiloVYXHD07U/7ZH7fojqe3VzAiAAAAAAAwWZF8BEYQnnKdMGlKQ00Foym905fP0IL2xkz5yp89qb29A6OcAQAAAAAAUDySj8AIOkKbzbQ11SmRsApGU3oJM73+xMNUlwy+BWzb26vP3Pp0haMCAAAAAACTDclHYASdoZGP7ZNks5mo2a0N+uB5yzPlHzy0Sfet2VnBiAAAAAAAwGRD8hEYQXjk42Ra7zHqvecu04q5rZnyR3/2uLr7BysYEQAAAAAAmExIPgIjCI98bJvEycfaZEKff+PxSqamlT+3u0f//qtnKhwVAAAAAACYLEg+AiPoqIJp12nHzp+q956zNFP+7n0bdC/TrwEAAAAAQAmQfARGkDPtunnyjnxM++B5h2v5rBZJkrv0Nz/8o7bv7a1wVAAAAAAAYKIj+QiMoKMrPO16co98lKSG2qS+/JYXqa4m+Jawc3+/Pnj9HzU4NFzhyAAAAAAAwERG8hEYQe6068k/8lGSjpk3VZ98zTGZ8kMbdusLtz1bwYgAAAAAAMBER/IRGEFnzm7Xk3/kY9pbTl6g1714fqb8jTvX6jdPba9gRAAAAAAAYCIj+QiMoKNKdruOMjP92+uO1eGp9R8l6W9//Jie291dwagAAAAAAMBERfIRiHD33A1nqij5KElNdTX6+iUnqrE2KUna0zOgD1z/B/UNDlU4MgAAAAAAMNGQfAQiegaG1D+Y3WilmqZdpy2fNUWfef1xmfJjm/fokzetqmBEAAAAAABgIqqpdABA3IRHPUqTe9r19Q9uGvX9UxZP00MbdmfadvUNauWS6XnbX7xyYUnjAwAAAAAAExsjH4GIjq7seo8t9TWqq6ne/yYXHj9XC9obM+WbHtuq9Tu7KhgRAAAAAACYSKo3qwLkEd7puq0Kp1yH1SYTeuvKRZrSEAySHnbp+gc3qjO0IQ8AAAAAAEA+JB+BiPBO19W22cxIWhtrdcnKRUomTJLU1T+k7z2wMWddTAAAAAAAgJGQfAQiwqP6qn3kY9qCaU167YvmZ8pb9/TqZ3/cLHevYFQAAAAAACDuSD4CEbu7stOuGfmYddKidp2+LLvZzOOb9+jOZ3dUMCIAAAAAABB3JB+BiNxp14x8DHvlsXO1dGZzpnzbqu16fHNn5QICAAAAAACxRvIRiMidds3Ix7BkwnTxyQs1rTn7efnp7zdrAztgAwAAAACAEZB8BCI6usPTrhn5GNVUX6N3nLZYjbVJSdLgsOu6BzZq576+CkcGAAAAAADihuQjEBEe+djezMjHkcyYUq9LTs3ugN0zMKTv3r9Bu/aTgAQAAAAAAFkkH4GI8MhHpl3nt2RGs9540mGZ8u6ufr372kfUOzBUwagAAAAAAECckHwEIthwpnAnHNaml6+YnSn/cVOnPvzDRzU07BWMCgAAAAAAxAXJRyBkcGhY+3oHM+V2Rj4e1NlHzNTJi6dlyr/80zZd8dPHNUwCEgAAAACAqkfyEQjp7BnIKbPm48GZmV5zwjwdMbslU/c/f9isf7rxSbmTgAQAAAAAoJqRfARCwpvN1CZNzXXJCkYzcSQTpotPWaSVS7IjIK9/cJP+5eZVJCABAAAAAKhiJB+BkOhmM2ZWwWgmlrqahL7zjpN14sK2TN01927QZ3/5NAlIAAAAAACqFMlHIKSji81mxqKlvkbffdcpOv6wqZm6b965TlfdvrqCUQEAAAAAgEoh+QiEhHe6bmOzmUPS2lCra991io6aMyVT9+XfrNYXfvUMIyABAAAAAKgyJB+BkPC0a0Y+Hrq2pjp9790rtXxWdhOar/x2jT52w5MaYhdsAAAAAACqBslHICQ88rGdkY9jMqOlXte/e6WOnJ0dAXn9g5v0wR/8QX2DQxWMDAAAAAAAjBeSj0BIZ1fuhjMYm1mtDfrRZafqpEXtmbpbn9imd333Ye3vG6xgZAAAAAAAYDyQfARCckc+Mu26FNqa6nTdX52ic4+cmam7d80uXfztB7Rrf18FIwMAAAAAAOVWU+kAgDjpzFnzkZGPxbr+wU1533vpUbPV0dWvxzbvkSQ9vnmPLvjSXbr0tEWaNaVhxHMuXrmwLHECAAAAAIDxwchHICR3t2tGPpZSMmG66CULdNqy6Zm63V39+sada7Xmhf0VjAwAAAAAAJQLyUcgJGe362ZGPpZawkx/ftxcvfyYOZm63oFhffe+9Xpw/a4KRgYAAAAAAMqB5COQ4u7qZLfrsjMznXPETF18ykLVJk2SNOzSjY9u1S2Pb9Wwe4UjBAAAAAAApULyEUjZ3zeoweFs4osNZ8rr2PlT9Z6zlqm1Ibv07L1rd+m6+zeqb2CogpEBAAAAAIBSIfkIpIQ3m5GkqY0kH8ttfnuj3nfucs2bmt1w5pnt+/TNu9blrL8JAAAAAAAmJpKPQEo42dXaUKOaJP89xsPUxlq95+xlWjG3NVO3bW+vvva7tfrDpo4KRgYAAAAAAMaK7AqQsrsrtN4jm82Mq7qahC5euVBnHz4zU9fVN6i3fOsB/eKxrRWMDAAAAAAAjA
XJRyAlPO26jc1mxl3CTK84do7ecOJ8JS3YiKZ/cFgf+sEfddXtz8rZiAYAAAAAgAmH5COQ0pGz0zXrPVbKSYum6Z1nLlZjbTJTd9Xtq3XZdb/Xvt6BUc4EAAAAAABxQ/IRSOkIjXxsZ+RjRS2d0aL3nbtMS2c0Z+puW7Vdf/GVe7V6+74KRgYAAAAAAIpB8hFIeWFvb+Z4Gms+VtyMlnr9/PIzdM4R2XUg1+3s0l989V7d8vjzFYwMAAAAAAAUiuQjkLL6hf2Z46Uzm0dpifEytalW//WOk/Wh85Zn6rr7h/T+6/+gT9/6lAaHhisYHQAAAAAAOBiSj4Akd8+Zznv4rCkVjAZhyYTpb192pK6+9CWa0lCTqf/WXet00Tfv1/qdXRWMDgAAAAAAjIbkIyBpx74+7e0dzJQPn9VSwWgwkvNXzNZNHzhTR87OJob/uKlTr/ry3brugY3shg0AAAAAQAyRfASUO+V6Rkud2lnzMZYWz2jWz99/ut5w4mGZup6BIX38hif1jmse1vbQup0AAAAAAKDySD4CElOuJ5Cmuhp98U0n6BuXnKj2ptpM/Z3P7tDLr7pLN/xxC6MgAQAAAACICZKPgHJHPh4+mynXE8Erjp2rX33kbJ131KxMXWf3gD78o0d10Tfu1xOb91QwOgAAAAAAIEk1B28CTH45yUfWe4yN6x/cdNA2Lz1qlqY21OqWJ55Xf2r360c2dug1X7lHJy5q18tWzNaUhmCE5MUrF5Y1XgAAAAAAkIvkI6pedKfr5Uy7nlDMTCcvmaZls1p06xPPa9XzeyVJLun3Gzv05JY9OveImTp12fTKBgoAAAAAQBUi+Yiqt6urXx3dA5ky064npmnNdbrk1EVa88J+3fz4Vr2wr0+S1Dc4rF+t2q671+xUd/+QLj1tUWYkJAAAAAAAKC/WfETVW709O+V6WnOdZrTUVzAajNXyWS364HmH69UnzFNjbTJT390/pH//1TM683O/1VW3P6s9oYQzAAAAAAAoD5KPqHprXghPuWbU42SQTJhOWzpdf3fBETrr8BmqS2a/1e3pGdBVt6/WmZ+7Q5/536e0fW9vBSMFAAAAAGByI/mIqsdmM5NXU32NXnnsXP3Dy4/UuUfOVEt9dqWJfX2D+uad63Tm5+7Q3//kMT0bWvcTAAAAAACUBslHVL3wtGuSj5NTc32NXrZiju796Hn68PmHq7Uhm4QcGHL99Peb9bIv3aV3ffdhPbBul9y9gtECAAAAADB5sOEMql7OyMfZ7HQ9mU1tqtWHzz9C7z5rqX740CZ95571en5Pdtr1HU+/oDuefkEnHDZV7zl7mV5x7BwlE1bBiAEAAAAAmNhIPqKqdXT1a+f+vkyZkY+T2/UPbsocN9XV6PJzl+vxzZ26e/VObQut/fjY5j16//V/0LTmOp25fIZOXNiuupoDB4pfvHLhuMQNAAAAAMBERfIRVS086nFqY61mTmGn62qSTJhevLBdL1rQptUv7Nfdq3do7Y6uzPu7u/r1i8e26vantuvM5TO0csl0NdYlR7kiAAAAAAAII/mIqrY6tNP14bNaZMYU22pkZjpi9hQdMXuKtnT26O7VO/TE5j1Kr/zY3T+k21Zt153P7tCpS6frjOUzcjavAQAAAAAAI+O3Z1S1nM1mZjPlGtL8tka95eSFevmKft2zdqce2bBbA0NBGrJvcFh3PrtD967ZqZMXT9M5R87U/LbGCkcMAAAAAEB8sds1qtqa0LTr5bPYbAZZ7c11evXx8/QPLz9Kf3bkTDXUZr9dDg677l+3S+d8/re64qePad2O/aNcCQAAAACA6sXIR1S16LRrIKqlvkYXrJijsw6fqQfX79Y9a3aqq29QUpCE/PEjm/WT32/Wq46bq8vPXaZj5k2tcMQAAAAAAMQHyUdUrT09A9q+N7TTNdOuMYqG2qTOOWKmTl82XY9s7NDdz+5QZ8+AJMlduuXx53XL48/rjOXT9bZTF+v8o2epJsngcgAAAABAdSP5iKoVnnI9pb5Gc1obKhgNJoraZEKnLZ2uUxZPU2NdUl/73RqtC+2Qfe+aXbp3zS7Nm9qgt566SG8+eYFmtLCLOgAAAACgOjEsB1Vr9fbslOvls9npGsVJJkxvPOkw/foj5+jrbz1Rx85vzXl/655e/fuvntHpn7lDH7j+D/rNU9s1MDRcoWgBAAAAAKgMRj6iaq0OjXxkvUccqmTC9Mrj5uoVx87RIxs7dO39G/W/TzyvweFgh+z+oWHd/Pjzuvnx59XeVKsLj5+r175ovk5a1E7CGwAAAAAw6ZF8RNXKTT6y0zWKd/2Dmw6oO23pdB0zr1UPb9ith9bv1r7ewcx7Hd0D+t4Dm/S9BzapralWK+a2asXcVi2a3qxkwnTxyoXjGT4AAAAAAGVH8hFVa01k2jVQKq0NtXrpUbN17hGz9My2fXp0c6eefn5vZjSkJHV2D+i+tbt039pdaqxN6ui5UzStuU5nHj5DLfV8awYAAAAATA78houqtK93QFv39GbKTLtGOSQTphXzWrViXqt6B4b0p6179OhznVq3o0seatczMKQ/bOrUe7/3e9UkTCctatc5R87U2YfP1Iq5rUokmJ4NAAAAAJiYSD6iKq0N7U7cXJfU/LbGCkaDatBQm9RJi6bppEXTtLd3QE89v1dPPb9Xa3d0aSg0InJw2PXg+t16cP1uff6Xz2hGS73OWD5dpy4NXounN7FWJAAAAABgwiD5iKqUs9P1LHa6xvhqbajVyiXTtXLJdPUODOnZ7fu06vm92rirW3t6BnLa7tzfpxsf3aobH90qSZrT2qBTl07TKUum6+TF7Vo2s4WRkQAAAACA2CL5iKq0JrTZzHI2m0EFNdQmdfxhbTr+sDa9+eQFemxzp+56dofufHaHHnuuU8Oe237b3l7d8OhW3ZBKRrY11eqkhe16yeJpesnidh03f6oaapMV+EgAAAAAADgQyUdUpWdDIx8PZ7MZxMSPHn5OkjRrSoMuOmmBLjxurtbu6NL6nfu1bkeXXtjXd8A5nd0D+s3TL+g3T78gKVhn8rC2Ri2a3qRLT1uskxa1q725blw/DgAAAAAA0kg+jsDMjpF0vKR5koYkbZH0iLuvH+c4EpJOl7RM0lxJe1Kx3O3uHeMZy2Syp2dAj2zMfvrYbAZx1VRXo+PmT9Vx86dKkvb3DWr9zi6t27FfG3d1a/veXkUGRmpo2LVxd7c27u7WXat3SgqWFnjJomB05MmL27VwGutGAgAAAADGB8nHEDN7o6SPK0g8jvT+fZI+5u6/K3McNZI+KulyBQnQqH4zu0nS37v7hnLGMhn91z3rta93UJI0pb5GJy+ZVuGIgMK01OcmI3sHhrRpd7c27urShl3d2tzRrYGhaDoyWGZgzQv79cPUyMoZLfU6eXG7Tl06XSuXTtMRs6awbiQAAAAAoCxIPkoys6SkqyW94yBNT5f0GzP7tLt/vEyxzJZ0s6SXjNKsTtIbJF1gZpe6+43liGUy2tM9oP+6JzuA9V1nLlFrQ20FIwIOXUNtUkfMnqIjZgfrlg4Nu
7Z29gQjH3d1afvePu3cf+BU7Z37+/S/T27T/z65TZLU3lSrkxdP08ql0/XihW1aMbeVdSMBAAAAACVB8jHwJeUmHrslfV/SowoSfSsVJPtqJSUk/ZOZ7Xb3L5UyCDNrlHSjchOPWyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxnLZPWde9ZpX19q1GNDjd515pIKRwSUTjJhWjCtSQumNenM5TP0l6cs0MZd3XpkY4ce2bBbj2zsyNlsKa2je0C3rdqu21ZtlyTVJk0r5rbqRQva9KKFbTpi9hQtndGixjoSkgAAAACA4lR98tHMLpT0wVDVKkmvcPfnIu1OkHSrstOgv2Bmt7v7EyUM518UJDrTfirpEncPD136rJldLOm7CpKhDZJ+ZGZHuHtvCWOZdDq7+/Vf927IlN995lJNbWTUIyavHzyU/TaW3lG7u29QG3d3a/3OLq3f2aWtnT0HrBs5MOR6bPMePbZ5j/77/o2Z+vltjVo2q0XLZjZrQXuT5kxt0OzWes2a0qBZrfWqryE5CQAAAADIVdXJx9SGLp8OVXVLenU08ShJ7v6YmV0k6W4Fox/T5766RLEcJukDoarHJV3s7gMjxHK9mS2U9JlU1QJJ75f0xVLEMlldffd67U+NemxtqNE7z1xc2YCACmiqr9HRc1t19NxWScG6kRt2dWn9ji5t2t2tLZ09Ghw+cN1ISdrS2aMtnT2669kdI74/tbFWbU21amusVWtjbaY8tbFWbY11mtpYq6mpcmtDraY01KilvkYtDTWqTSbK9jEDAAAAACqnqpOPkl6q3M1l/sPd1+Vr7O73mdlPJL05VfXnZrbc3deUIJb3KRjFmHbFSInHkC8oSFbOT5U/LJKPee3u6tc192bXenzP2UtZ6xFQsG7kUXNaddScIBk5NOzatrdXz+0ONrDZ2tmrnfv78iYkw/b0DGhPz4A2HrTlgeprEjnJyJb6GrXU16qlPqnGuho11ibVUJtQY21SjXVJNdQmU3VJNdYlcsuhNk11SRKbAAAAAFBB1Z58fF2kfHUB53xb2eSjJL1WQSKwlLFslHTbaI3dfdDMrpH0T6mqw8zsJe7+SAlimXS+ffc6dfUPSZLammr19tMXVzYgIKaSCdP8tkbNb2tUsMysNOyuzu4B7djXqx37+rRjf5/29Axob8+g9vYOqDv1f2ss+gaH1be/Xzv394/5WlEt9TVqa6rVtOY6tTXVaVpTbfBvc53aQ8dtTbVqb6pTe1Md61sCAAAAQIlUe/LxwtDxWndfW8A5d0vqVXaU4p9rjMlHM1si6ehQ1e3ufvBhRtKvlU0+pmMh+Rixa3+f/vu+DZnyX5+1VFMY9QgULGGmac1Bgu7IOQe+Pzg0rL29g+rqG1TPwJB6+ofUMzCk7v4h9abK3Zn6QfX0DwXJxsHhcYl/f9+g9vcNanNHT8Hn1Nck1N4USkg2B0nKtsagPKUhO0IzGLFZq+b6pKak/q1htCUAAAAASKri5KOZtUlaGKp6oJDz3L3fzH4v6YxU1fGjtS/QCZFyQbFIekjSoLJfx1LEMul86+51mZFZ7Yx6BEquJpnIJCeLMeyu/lQSsncglZAcGFJv6t++wWH1Dg5pcMjVPzSsgcFhDQwNa2DIU/8Gx/3p48FhDQy7BgaHD9hEp1h9g8PatrdX2/Ye2j5ejbXJ0PTx7HTyKfU1aq6PJi6Duik5U86D48bapMxsjB8NAAAAAFRO1SYflTvSUJKKWbdxrbLJx3Yzm+Pu28Y7FnfvNbOtyiZRV4whhklp5/4+XXtfdgW6y85Zppb6au72QHwkzNSQWqexlDvPu7uGhl0DQ66+wWAEZnf/kLr6B4PjvsFMuSdc3z+k/hKNxuwZCEZ/7tjXN6brmEkNNak1LGsSaqhLZsupNTDrU+tc1tUkVJdMqCZhqq1JqDaZUF3SVJsMjmtrEqpNWOY4/F5N0lSXbpdMqK4meC+ZMNUkgn/Tr5rIMclRAAAAAKOp5izM0kh5UxHnRtsulTSW5ONYY0knH6PXqXr/fd8G9QwEox6nN9fp0tMWVTgiAOVmZqpJmmqSUmNdUm1NhZ87ODSs7oF0wnJQ3X2paeOhBGX3wFB2ZGZ6xObgkAaGxjre8kDu2URmXCVMqkkklEho1ERlupxIJSvTSct06tIseAV1lq1TtkG4bfpcS9UnzGQ28r+JTNmUDMWZjq8maTmJ1pqc+BOh9w+sz9aFzs1Xn/r40x9rIhV7ULYRPx5JSiRC7RR8PEodj3id1OfNFfQhueRyuafrPDM6OL3ISxBbcH7ClBNrwkzJTNwkmwEAAFCcak4+tkbKu4s4tyNSnhKTWGrNrN7dxzbUZhJ537nBSMdv3rVOl52zVE111dzlARxMTTKh1mRCrYewLuzQcDCNvHcwO4U8naAM6odHTFr2DQxnjnsHgn8L2Fw8NoZd6h8aloYkaXzW8UTlRJORmeRuKlmZTvQmEqFjMyUSoeNQvYUSnenkZ/5EaKocOk6k2tsIx+lzw23zGW2lbR9lIYd0Qjd7jVSS17PnZpO+oevlSQRH63ISx56NJXzfzNcm8nUK6iynPNJ7kX9SbQ7848BIbdIJ85G+7tGvRbiPZL+GwR8E0sfpr1n485v5nB1Qp4LaKafdoV8n53KpymJiiCb8M21D10/H5zl12bY550ZiyNc/sudkzz9Y27xx5dwrG0e0v+Zcq5i4Ivca6WsY/qNUug+m6zXCH67Sf7TJHIf7f+QPWdnjkesVvVbe61vkGtk/FkWvW6himmc/99HvLz7C1+DAfhFtE/6elb5ubmy53y+in6fwe/m+N1n4ZEW/BnnOidw/G0/+9tn7HRjz6PfIvn/A98O83y8Pcq0ivqj5mmb/BxTavrzXH7ltnmsUHUuR1y/yOvlOGKm2mM/XzCn1umDF7Hx3rSrVnIlpiZSLWdgrumtB9FqVjqWg5KOZ/SnPW0etXbtWxxxzTBFhxNuwS1f9QPpycT/nx2xPz8D43hDApBBNQmR+gfARfjlIt0m/n7lG6JfFTL3nlLPHB/7CO+aFMwEAAIAq1liX0GHtRUzDirG1a9dK0oJDPb+ak48NkXJ/EedGk3uNkygWSRru6+vrWrVq1XMluFY1W5b6t5Bd1FFd6BvIh76BfOgbyIe+gXzoG8iHvoF86BslNCBp1fOVjqJkFkjqPtSTqzn5GB1dWMw2rfWRcnT0YSliKXT04yHH4u6TZ2hjDKVHlvJ5RhR9A/nQN5APfQP50DeQD30D+dA3kA99A+WSqHQAFbQ/Uo6OPhxNdHRh9FoTORYAAAAAAACgJKo5+bg3Um4v4ty2SHnf2EIpWSwDbDYDAAAAAACAuKjm5OP6SHlhEecuipTXxSSWscYBAAAAAAAAlEw1Jx9XRcrLizh3Wei4w923VSIWM2uQNG+U6wAAAAAAAAAVU7XJR3fvlLQpVHVaIeeZWZ2kk0JVT5QgnMci5YJikXSKcjcNKkUsAAAAAAAAQElU827XknSrpPemjpeZ2VJ3P9jU5bOUuyHMzWMNwt3Xm9nTko5KVZ1vZubufpBTL4iU
xxwLSocdwpAPfQP50DeQD30D+dA3kA99A/nQN5APfQPlUrUjH1N+Hin/dQHnRNvcUJpQcmJZJOllozU2sxpJ7wxVbZH0SIliAQAAAAAAAMas2pOPt0t6MlT+oJktydfYzE6TdFGo6hZ3X52n7WIz89DrdweJ5euSwjtVf97Makdp//eS5ofKVxUwUhIAAAAAAAAYN1WdfHT3YUlXhqqaJd1kZguibc3seEk/UfZzNizpYyWM5TlJXw1VHS/p+2ZWP0Isfynpk6GqLZK+UqpYAAAAAAAAgFIwBstJZvZVSZeHqrokfV/So5JqJZ0q6Y2p47R/cPcvjHLNxZLWh6rudPdzDxJHk6Q7Jb0kVL1F0nWS1klql/QqSeeE3u+TdL673zPatQEAAAAAAIDxRvJRkpklJV0j6W0FNHdJn3X3K0drdCjJx9R5cyTdIunEAmLZJ+nt7h5duxIAAAAAAACouKqedp3m7kPufqmkNyt3DcioBxSMMhw18TjGWLYpGGn5CUnb8jTrV7BBzQkkHgEAAAAAABBXjHwcgZkdq2DNxXmShiRtlfSwu68b5ziSkk6XtFzSbAUjHTdLutvdd49nLAAAAAAAAECxSD4CIWZ2jHITz1skPeLu60c9sfRxJBQknpdJmitpTyqWu929YzxjQaDSfcPM6iQdLWmFpDmSmiTtlbQ9Fce4/nEEWZXuG4ivuPUNM2tV8LNlnqRZkvZLeiEV16Pu3lWJuKpRXPqGmS1TsNTPXElTJPVI2iXpcUlPuPvgeMaD+OBZFFE8iwIYC5KPgCQze6Okjyv4RWAk90n6mLv/rsxx1Ej6qIINkOaN0KRf0k2S/t7dN5QzFgQq2TfMbL6Cza5eJelMBQ95+ayR9DVJX3P3vlLHggPF5ftGPmb2fklfiVR/0t3/uQLhVJW49Q0zO0vBz5YLJNXlaTakYHmZj7n7neMRVzWKQ99Izax5n6T3SzpqlKY7Jf23pE8z46Z8Ukm+oxVsOJl+nSCpMdTsz8bx+wXPojERh77Bs2g8xaFvFIJnUYSRfERVSz2AXy3pHQU0H1bwAP7xMsUyW9LNyt3tPJ+9ki519xvLEQsq3zfM7GWSfinJijz1T5Le5O6rShULclW6bxTCzA6TtErBSKYwHvjKKG59w8yaFDz0v0OFfy/5B3f/QrliqlZx6RtmNkvBxoaFPGukvSDpDe5+T6njqXZm9j+SXi6p+SBNxyWJwLNofMShb/AsGk9x6BuF4FkUUTWVDgCosC8p9xeBbknfl/SogtEhKyW9QVKtgg2a/snMdrv7l0oZhJk1SrpRuQ97WyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxkLMirdN5qU+7A3LOkxSXdL2iipQ1K7gg2q/kLZ0UzHSLrDzM509zUligW5Kt03CvF1Hfiwh/KLTd8ws2YFSaZzQtU9kn6jYITjdklJBVPnXiTpPAU/W1AeFe8bqSmTv1buqMs+Sb9Q0Cd2S2qRdJyCkU7TUm1mSfpfM1tJMqHkTtLBEwjjgmfR2IlD3+BZNJ7i0DcKwbMocrk7L15V+ZJ0oSQPvf4kacEI7U5Q8PCVbjck6bgSx/LvkVh+Iql+hHYXK5jukm63SVJDpT+Xk+0Vh74h6bWpa65TMP1p7ihtFyqYqheO+a5Kfx4n4ysOfaOAGN8Suu+qSLz/XOnP4WR9xa1vSLo1Es+1kmaN0r5W0uskvaLSn8vJ9opL35B0RSSORyUtydN2iqQfRdr/utKfy8n2krQh9PntlfSQgl/Yr4t87s8dh1h4Fo3RKw59QzyLxvIVh75RQIw8i/I64JUQUIVS62R8OlTVLenV7v5ctK27PybpIgV/7ZOCEQmfjrYbQyyHSfpAqOpxSRf7CGuluPv1kj4RqlqgYM0mlEiM+sYLki6TdKS7f87dn8/X0N03KZh+8Uyo+iwzOzvPKTgEMeobo8U4XdKXU8VeSR8q9z0Rv75hZn+lYIRS2ufd/VJ3fyHfOe4+4O4/d/dfljKWahezvvH20HFPKo71IzV0932S3qrgmSTtpWY20hqAOHTXSnqPgpFMU9z9FHd/n4IRyuOGZ9FYikPf4Fk0nuLQN/LiWRT5kHxEtXqpcqcd/YePskObu9+n4C/AaX9uZstLFMv7JDWEyle4+8Ao7b+gYGRE2odLFAcCsegb7n6fu3/rIH0h3H6fpE9Gqv98rHEgRyz6xkF8ScEUSUn6lILF31F+sekbZjZFwc+JtAck/Z9SXBuHJBZ9w8waFOxQm3bzSAnQSCyDkr4dvozyb5SDQ+Dun3D3b7v7Hwr9eV8mPIvGTBz6Bs+i8RSHvnEQPItiRCQfUa1eFylfXcA5346UX1uaUHJi2SjpttEap34ZuCZUdZiZFbNwPEYXp75RrNsj5WUViWLyinXfSC0M/7ZUcZWkz5frXjhAnPrGJZLaQuUr3H04T1uUX1z6xvRIudBfBldHytNGbIWJjmdRlArPolWMZ1GMhuQjqtWFoeO17r62gHPuVjB0PG3Mf8kzsyWSjg5V3e4eLJRxEL+OlPmrYunEom8cov2R8kRYjHoiiW3fSG0u8s1U0SVdFtO/hk9Wceob7wkdP+Pud5foujg0cekbnQq+N6QV+vOhJVLOO3UfExPPoigxnkWrFM+iOBiSj6g6ZtamYFHktAcKOc/d+yX9PlRViqlHJ0TKBcWiYGHhwRLHUvVi1jcOxZJIeVtFopiEJkDf+JSkxanjq939njLdBxFx6htmNkPBztVpt471mjh0ceob7t6lYJfatPMKPPWloeP0xgaYXHgWRSnxLFq9eBbFqEg+ohodHSkXsw5FeMRCu5nNqUQs7t4raWuoakW+tihKnPrGoXh9pHx/BWKYrGLbN8zsFGUX896uYEdKjJ849Y1TIuX7pWDxdzP7iJndY2bPm1lf6t/7zOxTZnb4GO+LkcWpb0jSf4aOjzWzUTcJMbOTJb0rVPUtd99bgjgQLzyLopR4Fq1CPIuiECQfUY2WRsqbijg32jZ6rUrFMtY4EIhT3yiKmbVIujxU1S/pxvGMYZKLZd8ws1pJ31H25/lH3L2jVNdHQeLUN14cKT9tZm+Q9LSk/yfpDElzJNWl/j1N0sckPWVmXzOz+jHeH7ni1DekYI2+8M+F/0x93Y8KNzKzOWZ2haTfSkr3iYckXVmCGBA/PIuiJHgWrU48i6JQJB9RjVoj5d1FnBv9RjolJrHU8ktjScSpbxTri5LmhsrfcHemupROXPvGP0o6NnV8m7v/oITXRmHi1DdmRsrnKtg5eUaq7JJ2SHpe0lCoXVLBbre/MbPGMcaArDj1DaXW8XuTpKsUTJc1BV/3p8xsj5mtN7N0//icgrXaBiR9XdJLU1O3MfnwLIpS4Vm0OvEsioKQfEQ1ii6e3jtiq5H1HORaEzkWTNCvh5ldqtxNJjZJ+vh43b9KxK5vmNnRCkatpe/xvlJcF0WLU99oi5S/qCDB1CfpnyXNd/dZ7j5Pwe7Hlys30XCGgkQTSiNOfUNSsJ6ku39EwS+Kd4bealWwVteMUN0mSa9198vdPbqJBCaP2PV
TTDw8i1YnnkVRDJKPqEYNkXJ/Eef2RcpjHSESp1gwAb8eZnaOpG+HqgYkvYV1uUouVn3DzEzB1z09yuRf3H3dWK+LQxKnvhH9xb9WwfeEV7n7J939+fQb7r7H3b8u6UxJu0LnvD211h/GLk59Q5JkZgkz+4ikuySdc5DmCyXdYma/NjOm1E5eseunmFh4Fq1OPIuiWCQfUY2if9GtK+Lc6HSS6F98J3IsmGBfDzM7SdIvlI3TJb3T3Vncu/Ti1jcuVzBKTZKeUDDCDZURp74x0oilL7r7HflOcPenJP1tpPrDY4wDgTj1DZlZg6SbFaz/OStVfbuk1yqYKlknqV1BUvLbyk7NP1/SI2Z24lhjQCzFqp9iYuFZtKrxLIqikHxENYpOHYr+xXc00b/ojnUaUpxiwQT6epjZcZJ+pdy1mi539++X875VLDZ9w8wWSPpMquiSLnP3gbFcE2MSm74haV+k7JL+o4DzrlewO2Xa+WOMA4E49Q1J+rKkV4bKV7r7Be5+o7tvc/cBd+9097vc/T2SXqZsYqpd0s9SG0pgcolbP8UEwbNo9eJZFIeC5COqUXQKQHsR57ZFytFf9IpVqlgG3D069QXFi1PfyCu1M+ntCtZsS/uwu3+jXPdErPrG15XdfOIbjC6ouDj1jWgsT4enWufj7oOS7glVzTKzw8YYC2LUN1Lrcv11qOoX7v6ZfO0lKTVi9mOhqkWSLhtLHIglnkVRNJ5Fqx7PoigayUdUo/WR8sIizl0UKY91XYtSxcL6GqURp74xIjM7XNIdyk6Zk6R/dPcvl+N+yIhF3zCz10i6MFXcJun/HOq1UDKx6BspayPlTUWcuzFSju6cjeLFqW+8RcHmQ2lfKfC8byp3DcDXjzEOxA/PoigKz6LVjWdRHKqaSgcAVMCqSHl5EecuCx13uPu2MsRy50gNw1LrNs0b5To4NHHqGwdILfh/h4K1udI+4e6fK/W9cIC49I3wpg9Nkn4frPedV/Tn/IfM7JJQ+VPu/t0xxIP49A1J+lOkXMyutdG2xUy9xMji1DeOj5QfKeQkd+8ys6dD5x8zxjgQPzyLomA8i0I8i+IQkXxE1XH3TjPbpOxfdk8r5Dwzq5N0UqjqiRKE81ikfJqk7xRw3inK/f9biliqXsz6RvQeiyT9VlJ4KuSn3P1fS30vHCimfaNVuessFaJduVPq2koWTZWKWd94UsEmIclUeVoR50bb7hqxFQoWs77RHCkXszZfV+iY3YwnH55FURCeRTECnkVRMKZdo1rdGjpelvor3sGcpdyRIDePNQh3Xy/p6VDV+XaQPx2lXBApjzkWZMSib4Sl1l67Q7lToT7n7h8v5X1wULHrG4iNWPQNd9+j3BFLx5tZoc96Lw4dD0jaPNZ4ICkmfUNSR6Q8p4hzwyOcSEpPMjyLohA8iwIYK5KPqFY/j5T/esRWo7e5oTSh5MSySMHuknmZWY2kd4aqtqjA6VMoSJz6hsxsroKHvfAvrP/P3f+xVPdAwSreN9z9Kne3Ql+SlkQu8clIm6vGEg8yKt43Qn4aOp6qg/xMkSQzWyLp5FDVA+7eXaJ4ql1c+saaSDmaOBpRam23xaGqZ0sQC+KHZ1HkxbMowngWxaEi+YhqdbuC6WlpH0z98jUiMztN0kWhqlvcfXWetovNzEOv3x0klq9LCu8O+Hkzqx2l/d9Lmh8qX+XufpB7oHCx6RtmNjMVz+Gh6v9w97872AeBsohN30DsxKlvXCdpe6j82dQ03tF8UbnPhP99kPYoXFz6xi8j5SvNbMqILXNF13H7VQHnoMJ4FkU+PIsiH55FUW4kH1GV3H1Y0pWhqmZJN5nZgmhbMzte0k+U/f8yLOljJYzlOUlfDVUdL+n7ZlY/Qix/KemToaotKnzHShQgLn3DzNol/VrSilD119z9b0pxfRQvLn0D8ROnvuHu+yX931DVCZJ+lvqeEo2l3sy+Kul1oepnJV1bqniqXVz6hrvfLenhUNUySbemplIewMyazOxq5faNvZK+XYp4EC88i2IkPIsCKCU2nEHVcvebzOxrki5PVR0j6Skz+76kRyXVSjpV0htTx2kfdffo4txj9XFJZ0t6Sap8kaTTzew6SesULMr7KknnhM7pk/QWdy9mN1MUICZ94wMKkgZhrzCz6NS50Wx293NLFA8Um76BGIpZ3/iWgp8Xf5kqXyhpjZn9WNLjkgYVjGJ5k4Iplmn7Jb3B3QdKHE9Vi1HfuEzSXZJaUuUzFfSLX0h6UMF6js0KEk9vkDQ9cv7fuPvOEsZT9czs9ZI+P8Jb0VGp3zeznhHaXeHuPytRODyLxkhM+gbPojEUk74BFI3kI6rdhxR8o35bqtws6T152rqkz7r7F0odhLt3m9mrJd0i6cRU9XxJ+dZS2Sfp7e5+T6ljQUal+0ZyhLpCNioI43t8eVS6byC+YtE33N3N7B0KRtC9OVU9TdJ7Rzlti6TXufuTo7TBoat433D3P5rZhZJ+qOwmMvUKkkwX5T1R6pX0EXf/binjgaRgl9hlBbSbN8r5JcGzaOzEoW/wLBpPcegbQNGYdo2q5u5D7n6pgl/ORvuF6wFJ57v7laO0GWss2xSMfPiEpG15mvUrWBT8BHePLmKPEopT30C80DeQT5z6hrv3u/tbFIxufHSUpnsUjKA4wd0fHqUdxiAufcPd75J0rKR/U/5njbRuSddIerG7f6Mc8SBeeBYFAJSLsTYwkGVmxyqYbjRP0pCkrZIedvd14xxHUtLpkpZLmq3gr8ubJd3t7rvHMxYE4tI3ED/0DeQTp75hZkdIenEqljoFU2xXSXrI3QfHO55qF4e+YWYm6WhJL5I0U8HIzB5JuxX0jUfdvS/vBTCp8SwKACglko8AAAAAAAAAyoJp1wAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACMwMwSZnaMmb3dzP7TzO43s24z89Dr3ErHmWZmGyKxHcrrd6WMqaaUFwMAAAAAAAAmAzP7H0kvl9Rc6VjGWWcpL0byEQAAAAAAADjQSZp4iccNkgaLPGeepMZQ+Qcli0YkHwEAAAAAAICD6ZP0uKTfS2qRdEllwxmZu59bTHszq5e0Rdnk4y5JN5QyJpKPAAAAAAAAwIGulfScgoTjE+4+IElm9g7FNPl4CF4raXqofJ2795XyBiQfAQAAAAAAgAh3/8R43cvMTNKJklZImiXJJG2X9Ad3/1MZb/3uSPk7pb4ByUcAAAAAAACgAsxsiqSPKkgCzs7TZrWk/+vuJV2L0cwWS3ppqOpBd3+ylPeQpESpLwgAAAAAAABgdGZ2qqTVkj6mPInHlMMlXW9mPzaz2hKG8C4FIyzTri7htTMY+QgAAAAAAACMIzP7M0k3S2oKVT+Tql
urYMfqIyW9SdKC1PsXSXJJby7B/ROS3hGq6pL0o7FedyQkHwEAAAAAAIBxYmazJP1A2cRjr6T3S7rG3T3S9uOSviTpslTVm8zsZne/boxhvEzZpKYk/cjd943xmiNi2jUAAAAAAAAwfj6r7DTrYUmvc/f/iiYeJcnde9z9vZL+J1T9r6mRi2MR3WimLFOuJZKPAAAAAAAAwLgwszmS3hqqutrdf1nAqR+SNJA6XiTpVWOIYaak14SqVrn7/Yd6vYMh+QgAAAAAAACMjzdKqguVv1TISe6+VdLtoaoLxhDDpZLCG9d8ZwzXOiiSjwAAAAAAAMD4OCt0vM7dny7i3IdCxyvHEMO7Qsf9kq4dw7UOiuQjAAAAAAAAMD5OCB3/qchzt4eODzuUm5vZaZJWhKpudPedh3KtQrHbNQAAAAAAADA+poeOX21mB2wyU6D2Qzxv3DaaSWPkIwAAAAAAADA+2kp0naZiTzCzFklvClVtVO46kmXByEcAAAAAAABgfHRLak0dd0jaPY73foukllD5GncfLvdNST4CAAAAAAAA42OnssnHn7j7ZeN4778KHQ9LumY8bsq0awAAAAAAAGB8hHe3Pma8bmpmx0g6NVR1m7tvGo97k3wEAAAAAAAAxsdvQ8enmtmMcbrvX0XK3xmn+5J8BAAAAAAAAMbJTyUNpo6Tkv6h3Dc0szpJbwtV7ZB0Y7nvm0byEQAAAAAAABgH7r5B0g9CVX9rZi8r5hoWqCvilL+QFB5hea27DxRzz7Eg+QgAAAAAAACMnyskPZ86rpF0k5n9nZk1jHaSmc01sw8qWDfyxCLuV7Ep15Jk7j6e9wMAAAAAAABiz8xeL+nzI7w1RdKsUHmrpJ4R2l3h7j/Lc+3TJP1S2Z2vpWAn7F9JelTSbgXTstskHaEg2fhiSZZqe5q7P1DAx7BQ0nplByDe5+5nHOy8UqoZz5sBAAAAAAAAE0SrpGUFtJs3yvkjcvf7zexUSTcoSC5KwdTot6ZeBzNUQBtJeqdyZz5fXeB5JcO0awAAAAAAAGCcuftTko6V9F5Jqwo4ZZWkL0p6sbs/fLDGZmYKko9p+yT9+BBCHROmXQMAAAAAAAAVZmbzJZ0qabakdkn9kjokrZX0pLvvqGB4h4zkIwAAAAAAAICyYNo1AAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMri/wMZEmFf1HC9MAAAAABJRU5ErkJggg==\n",
217
+ "text/plain": [
218
+ "<Figure size 1500x750 with 1 Axes>"
219
+ ]
220
+ },
221
+ "metadata": {
222
+ "needs_background": "light"
223
+ },
224
+ "output_type": "display_data"
225
+ }
226
+ ],
227
+ "source": [
228
+ "gene_detection_counts = [i for i in gene_detection_counts_dict.values()]\n",
229
+ "import seaborn as sns\n",
230
+ "import matplotlib.pyplot as plt\n",
231
+ "plt.figure(figsize=(10,5), dpi=150)\n",
232
+ "plt.rcParams.update({'font.size': 18})\n",
233
+ "count_plot = sns.distplot(gene_detection_counts).set_title(f\"# Cells Expressing Each\\nProtein-Coding or miRNA Gene\")"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 47,
239
+ "id": "missing-bradley",
240
+ "metadata": {},
241
+ "outputs": [
242
+ {
243
+ "data": {
244
+ "text/plain": [
245
+ "27454"
246
+ ]
247
+ },
248
+ "execution_count": 47,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "len(gene_detection_counts)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 55,
260
+ "id": "perfect-signal",
261
+ "metadata": {},
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "25424"
267
+ ]
268
+ },
269
+ "execution_count": 55,
270
+ "metadata": {},
271
+ "output_type": "execute_result"
272
+ }
273
+ ],
274
+ "source": [
275
+ "len([i for i in gene_detection_counts if i > 0])"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 56,
281
+ "id": "faced-theory",
282
+ "metadata": {},
283
+ "outputs": [
284
+ {
285
+ "data": {
286
+ "text/plain": [
287
+ "22735"
288
+ ]
289
+ },
290
+ "execution_count": 56,
291
+ "metadata": {},
292
+ "output_type": "execute_result"
293
+ }
294
+ ],
295
+ "source": [
296
+ "len([i for i in gene_detection_counts if i > 100])"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": 57,
302
+ "id": "tough-workplace",
303
+ "metadata": {},
304
+ "outputs": [
305
+ {
306
+ "data": {
307
+ "text/plain": [
308
+ "21167"
309
+ ]
310
+ },
311
+ "execution_count": 57,
312
+ "metadata": {},
313
+ "output_type": "execute_result"
314
+ }
315
+ ],
316
+ "source": [
317
+ "len([i for i in gene_detection_counts if i > 1000])"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 49,
323
+ "id": "cooperative-camcorder",
324
+ "metadata": {},
325
+ "outputs": [
326
+ {
327
+ "data": {
328
+ "text/plain": [
329
+ "173152.0299000284"
330
+ ]
331
+ },
332
+ "execution_count": 49,
333
+ "metadata": {},
334
+ "output_type": "execute_result"
335
+ }
336
+ ],
337
+ "source": [
338
+ "gene_detection_event_digest = crick.tdigest.TDigest()\n",
339
+ "gene_detection_event_digest.update(gene_detection_counts)\n",
340
+ "gene_detection_event_digest.quantile(0.5)"
341
+ ]
342
+ }
343
+ ],
344
+ "metadata": {
345
+ "kernelspec": {
346
+ "display_name": "Python 3 (ipykernel)",
347
+ "language": "python",
348
+ "name": "python3"
349
+ },
350
+ "language_info": {
351
+ "codemirror_mode": {
352
+ "name": "ipython",
353
+ "version": 3
354
+ },
355
+ "file_extension": ".py",
356
+ "mimetype": "text/x-python",
357
+ "name": "python",
358
+ "nbconvert_exporter": "python",
359
+ "pygments_lexer": "ipython3",
360
+ "version": "3.10.11"
361
+ }
362
+ },
363
+ "nbformat": 4,
364
+ "nbformat_minor": 5
365
+ }
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py ADDED
@@ -0,0 +1,165 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # run with:
5
+ # deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
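+ # ds_config.json is the DeepSpeed configuration file (e.g. ZeRO and optimizer settings) consumed via the --deepspeed flag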
6
+
7
+ import datetime
8
+
9
+ # imports
10
+ import os
11
+
12
+ os.environ["NCCL_DEBUG"] = "INFO"
13
+ os.environ["OMPI_MCA_opal_cuda_support"] = "true"
14
+ os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
15
+
16
+ import pickle
17
+ import random
18
+ import subprocess
19
+
20
+ import numpy as np
21
+ import pytz
22
+ import torch
23
+ from datasets import load_from_disk
24
+ from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
+
26
+ from geneformer import GeneformerPretrainer
27
+
28
+ seed_num = 0
29
+ random.seed(seed_num)
30
+ np.random.seed(seed_num)
31
+ seed_val = 42
32
+ torch.manual_seed(seed_val)
33
+ torch.cuda.manual_seed_all(seed_val)
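+ # note: Python and NumPy are seeded with seed_num (0), while torch uses a separate seed_val (42)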
34
+
35
+ # set local time/directories
36
+ timezone = pytz.timezone("US/Eastern")
37
+ rootdir = "/parent_ouput_directory"
38
+
39
+ # set model parameters
40
+ # model type
41
+ model_type = "bert"
42
+ # max input size
43
+ max_input_size = 2**11 # 2048
44
+ # number of layers
45
+ num_layers = 6
46
+ # number of attention heads
47
+ num_attn_heads = 4
48
+ # number of embedding dimensions
49
+ num_embed_dim = 256
50
+ # intermediate size
51
+ intermed_size = num_embed_dim * 2
52
+ # activation function
53
+ activ_fn = "relu"
54
+ # initializer range, layer norm, dropout
55
+ initializer_range = 0.02
56
+ layer_norm_eps = 1e-12
57
+ attention_probs_dropout_prob = 0.02
58
+ hidden_dropout_prob = 0.02
59
+
60
+
61
+ # set training parameters
62
+ # total number of examples in Genecorpus-30M after QC filtering:
63
+ num_examples = 27_406_208
64
+ # number gpus
65
+ num_gpus = 12
66
+ # batch size for training and eval
67
+ geneformer_batch_size = 12
68
+ # max learning rate
69
+ max_lr = 1e-3
70
+ # learning schedule
71
+ lr_schedule_fn = "linear"
72
+ # warmup steps
73
+ warmup_steps = 10_000
74
+ # number of epochs
75
+ epochs = 3
76
+ # optimizer
77
+ optimizer = "adamw"
78
+ # weight_decay
79
+ weight_decay = 0.001
80
+
81
+
82
+ # output directories
83
+ current_date = datetime.datetime.now(tz=timezone)
84
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
85
+ run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
86
+ training_output_dir = f"{rootdir}/models/{run_name}/"
87
+ logging_dir = f"{rootdir}/runs/{run_name}/"
88
+ model_output_dir = os.path.join(training_output_dir, "models/")
89
+
90
+
91
+ # ensure not overwriting previously saved model
92
+ model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
93
+ if os.path.isfile(model_output_file):
94
+ raise Exception("Model already saved to this directory.")
95
+
96
+
97
+ # make training and model output directories
98
+ subprocess.call(f"mkdir {training_output_dir}", shell=True)
99
+ subprocess.call(f"mkdir {model_output_dir}", shell=True)
100
+
101
+
102
+ # load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/token_dictionary.pkl)
103
+ with open("token_dictionary.pkl", "rb") as fp:
104
+ token_dictionary = pickle.load(fp)
105
+
106
+ # model configuration
107
+ config = {
108
+ "hidden_size": num_embed_dim,
109
+ "num_hidden_layers": num_layers,
110
+ "initializer_range": initializer_range,
111
+ "layer_norm_eps": layer_norm_eps,
112
+ "attention_probs_dropout_prob": attention_probs_dropout_prob,
113
+ "hidden_dropout_prob": hidden_dropout_prob,
114
+ "intermediate_size": intermed_size,
115
+ "hidden_act": activ_fn,
116
+ "max_position_embeddings": max_input_size,
117
+ "model_type": model_type,
118
+ "num_attention_heads": num_attn_heads,
119
+ "pad_token_id": token_dictionary.get("<pad>"),
120
+ "vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
121
+ }
122
+
123
+ config = BertConfig(**config)
124
+ model = BertForMaskedLM(config)
125
+ model = model.train()
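+ # the model is a randomly initialized BertForMaskedLM (pretraining from scratch) left in training mode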
126
+
127
+ # define the training arguments
128
+ training_args = {
129
+ "learning_rate": max_lr,
130
+ "do_train": True,
131
+ "do_eval": False,
132
+ "group_by_length": True,
133
+ "length_column_name": "length",
134
+ "disable_tqdm": False,
135
+ "lr_scheduler_type": lr_schedule_fn,
136
+ "warmup_steps": warmup_steps,
137
+ "weight_decay": weight_decay,
138
+ "per_device_train_batch_size": geneformer_batch_size,
139
+ "num_train_epochs": epochs,
140
+ "save_strategy": "steps",
141
+ "save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
142
+ "logging_steps": 1000,
143
+ "output_dir": training_output_dir,
144
+ "logging_dir": logging_dir,
145
+ }
146
+ training_args = TrainingArguments(**training_args)
147
+
148
+ print("Starting training.")
149
+
150
+ # define the trainer
151
+ trainer = GeneformerPretrainer(
152
+ model=model,
153
+ args=training_args,
154
+ # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
155
+ train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
156
+ # file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
157
+ example_lengths_file="genecorpus_30M_2048_lengths.pkl",
158
+ token_dictionary=token_dictionary,
159
+ )
160
+
161
+ # train
162
+ trainer.train()
163
+
164
+ # save model
165
+ trainer.save_model(model_output_dir)
examples/tokenizing_scRNAseq_data.ipynb ADDED
@@ -0,0 +1,72 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
6
+ "metadata": {
7
+ "tags": []
8
+ },
9
+ "source": [
10
+ "## Tokenizing .loom single cell RNA-seq data to rank value encoding .dataset format"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "350e6252-b783-494b-9767-f087eb868a15",
16
+ "metadata": {},
17
+ "source": [
18
+ "#### Input data is a directory with .loom files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. \n",
19
+ "\n",
20
+ "#### Genes should be labeled with Ensembl IDs (row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (column attribute \"n_counts\") to be used for normalization.\n",
21
+ "\n",
22
+ "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
23
+ "\n",
24
+ "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
25
+ "\n",
26
+ "#### If one's data is in other formats besides .loom, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom format prior to running the transcriptome tokenizer."
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "080fdd9c-0c48-4d5d-a254-52b6c53cdf78",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "from geneformer import TranscriptomeTokenizer"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "id": "37205758-aa52-4443-a383-0638519ee8a9",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
47
+ "tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
48
+ ]
49
+ }
50
+ ],
51
+ "metadata": {
52
+ "kernelspec": {
53
+ "display_name": "Python 3 (ipykernel)",
54
+ "language": "python",
55
+ "name": "python3"
56
+ },
57
+ "language_info": {
58
+ "codemirror_mode": {
59
+ "name": "ipython",
60
+ "version": 3
61
+ },
62
+ "file_extension": ".py",
63
+ "mimetype": "text/x-python",
64
+ "name": "python",
65
+ "nbconvert_exporter": "python",
66
+ "pygments_lexer": "ipython3",
67
+ "version": "3.10.11"
68
+ }
69
+ },
70
+ "nbformat": 4,
71
+ "nbformat_minor": 5
72
+ }
fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.02,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "relu",
9
+ "hidden_dropout_prob": 0.02,
10
+ "hidden_size": 256,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2"
15
+ },
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 512,
18
+ "label2id": {
19
+ "LABEL_0": 0,
20
+ "LABEL_1": 1,
21
+ "LABEL_2": 2
22
+ },
23
+ "layer_norm_eps": 1e-12,
24
+ "max_position_embeddings": 2048,
25
+ "model_type": "bert",
26
+ "num_attention_heads": 4,
27
+ "num_hidden_layers": 6,
28
+ "pad_token_id": 0,
29
+ "position_embedding_type": "absolute",
30
+ "problem_type": "single_label_classification",
31
+ "transformers_version": "4.6.0",
32
+ "type_vocab_size": 2,
33
+ "use_cache": true,
34
+ "vocab_size": 25426
35
+ }
fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/trainer_state.json ADDED
@@ -0,0 +1,150 @@
1
+ {
2
+ "best_metric": 0.39658036828041077,
3
+ "best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
4
+ "epoch": 0.9,
5
+ "global_step": 7020,
6
+ "is_hyper_param_search": true,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.1,
12
+ "learning_rate": 0.00034606438343856935,
13
+ "loss": 0.911,
14
+ "step": 780
15
+ },
16
+ {
17
+ "epoch": 0.1,
18
+ "eval_accuracy": 0.4531576503366612,
19
+ "eval_loss": 1.4550466537475586,
20
+ "eval_runtime": 66.5164,
21
+ "eval_samples_per_second": 259.004,
22
+ "step": 780
23
+ },
24
+ {
25
+ "epoch": 0.2,
26
+ "learning_rate": 0.0006921287668771387,
27
+ "loss": 0.6273,
28
+ "step": 1560
29
+ },
30
+ {
31
+ "epoch": 0.2,
32
+ "eval_accuracy": 0.5953680055723242,
33
+ "eval_loss": 0.846651554107666,
34
+ "eval_runtime": 66.1267,
35
+ "eval_samples_per_second": 260.53,
36
+ "step": 1560
37
+ },
38
+ {
39
+ "epoch": 0.3,
40
+ "learning_rate": 0.0007330550166223805,
41
+ "loss": 0.5592,
42
+ "step": 2340
43
+ },
44
+ {
45
+ "epoch": 0.3,
46
+ "eval_accuracy": 0.5935105641978176,
47
+ "eval_loss": 1.0599186420440674,
48
+ "eval_runtime": 66.2608,
49
+ "eval_samples_per_second": 260.003,
50
+ "step": 2340
51
+ },
52
+ {
53
+ "epoch": 0.4,
54
+ "learning_rate": 0.0006283471571048975,
55
+ "loss": 0.3714,
56
+ "step": 3120
57
+ },
58
+ {
59
+ "epoch": 0.4,
60
+ "eval_accuracy": 0.686324587880195,
61
+ "eval_loss": 1.184874415397644,
62
+ "eval_runtime": 66.1411,
63
+ "eval_samples_per_second": 260.473,
64
+ "step": 3120
65
+ },
66
+ {
67
+ "epoch": 0.5,
68
+ "learning_rate": 0.0005236392975874146,
69
+ "loss": 0.2976,
70
+ "step": 3900
71
+ },
72
+ {
73
+ "epoch": 0.5,
74
+ "eval_accuracy": 0.7681100534014396,
75
+ "eval_loss": 0.6318939328193665,
76
+ "eval_runtime": 66.3309,
77
+ "eval_samples_per_second": 259.728,
78
+ "step": 3900
79
+ },
80
+ {
81
+ "epoch": 0.6,
82
+ "learning_rate": 0.0004189314380699318,
83
+ "loss": 0.2564,
84
+ "step": 4680
85
+ },
86
+ {
87
+ "epoch": 0.6,
88
+ "eval_accuracy": 0.7807058277223126,
89
+ "eval_loss": 0.7283642888069153,
90
+ "eval_runtime": 66.3416,
91
+ "eval_samples_per_second": 259.686,
92
+ "step": 4680
93
+ },
94
+ {
95
+ "epoch": 0.7,
96
+ "learning_rate": 0.0003142235785524487,
97
+ "loss": 0.2336,
98
+ "step": 5460
99
+ },
100
+ {
101
+ "epoch": 0.7,
102
+ "eval_accuracy": 0.8563965637334572,
103
+ "eval_loss": 0.5184123516082764,
104
+ "eval_runtime": 66.3416,
105
+ "eval_samples_per_second": 259.686,
106
+ "step": 5460
107
+ },
108
+ {
109
+ "epoch": 0.8,
110
+ "learning_rate": 0.0002095157190349659,
111
+ "loss": 0.1731,
112
+ "step": 6240
113
+ },
114
+ {
115
+ "epoch": 0.8,
116
+ "eval_accuracy": 0.8288832133735778,
117
+ "eval_loss": 0.5823884010314941,
118
+ "eval_runtime": 66.1535,
119
+ "eval_samples_per_second": 260.425,
120
+ "step": 6240
121
+ },
122
+ {
123
+ "epoch": 0.9,
124
+ "learning_rate": 0.00010480785951748295,
125
+ "loss": 0.1451,
126
+ "step": 7020
127
+ },
128
+ {
129
+ "epoch": 0.9,
130
+ "eval_accuracy": 0.886812166241003,
131
+ "eval_loss": 0.39658036828041077,
132
+ "eval_runtime": 66.3555,
133
+ "eval_samples_per_second": 259.632,
134
+ "step": 7020
135
+ }
136
+ ],
137
+ "max_steps": 7800,
138
+ "num_train_epochs": 1,
139
+ "total_flos": 0,
140
+ "trial_name": null,
141
+ "trial_params": {
142
+ "learning_rate": 0.0008039341830649843,
143
+ "lr_scheduler_type": "polynomial",
144
+ "num_train_epochs": 1,
145
+ "per_device_train_batch_size": 12,
146
+ "seed": 73.15243080311434,
147
+ "warmup_steps": 1812.6785581609881,
148
+ "weight_decay": 0.2588277764570262
149
+ }
150
+ }
geneformer-12L-30M/config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 25426
23
+ }
geneformer/__init__.py ADDED
1
+ from . import tokenizer
2
+ from . import pretrainer
3
+ from . import collator_for_classification
4
+ from . import in_silico_perturber
5
+ from . import in_silico_perturber_stats
6
+ from .tokenizer import TranscriptomeTokenizer
7
+ from .pretrainer import GeneformerPretrainer
8
+ from .collator_for_classification import DataCollatorForGeneClassification
9
+ from .collator_for_classification import DataCollatorForCellClassification
10
+ from .emb_extractor import EmbExtractor
11
+ from .in_silico_perturber import InSilicoPerturber
12
+ from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/collator_for_classification.py ADDED
@@ -0,0 +1,602 @@
1
+ """
2
+ Geneformer collator for gene and cell classification.
3
+
4
+ Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import warnings
9
+ from enum import Enum
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from transformers import (
13
+ DataCollatorForTokenClassification,
14
+ SpecialTokensMixin,
15
+ BatchEncoding,
16
+ )
17
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
+ from transformers.utils.generic import _is_tensorflow, _is_torch
19
+
20
+ from .pretrainer import token_dictionary
21
+
22
+ EncodedInput = List[int]
23
+ logger = logging.get_logger(__name__)
24
+ VERY_LARGE_INTEGER = int(
25
+ 1e30
26
+ ) # This is used to set the max input length for a model with infinite size input
27
+ LARGE_INTEGER = int(
28
+ 1e20
29
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
30
+
31
+ # precollator functions
32
+
33
+ class ExplicitEnum(Enum):
34
+ """
35
+ Enum with more explicit error message for missing values.
36
+ """
37
+
38
+ @classmethod
39
+ def _missing_(cls, value):
40
+ raise ValueError(
41
+ "%r is not a valid %s, please select one of %s"
42
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
43
+ )
44
+
45
+ class TruncationStrategy(ExplicitEnum):
46
+ """
47
+ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
48
+ tab-completion in an IDE.
49
+ """
50
+
51
+ ONLY_FIRST = "only_first"
52
+ ONLY_SECOND = "only_second"
53
+ LONGEST_FIRST = "longest_first"
54
+ DO_NOT_TRUNCATE = "do_not_truncate"
55
+
56
+
57
+
58
+ class PaddingStrategy(ExplicitEnum):
59
+ """
60
+ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
61
+ in an IDE.
62
+ """
63
+
64
+ LONGEST = "longest"
65
+ MAX_LENGTH = "max_length"
66
+ DO_NOT_PAD = "do_not_pad"
67
+
68
+
69
+
70
+ class TensorType(ExplicitEnum):
71
+ """
72
+ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
73
+ tab-completion in an IDE.
74
+ """
75
+
76
+ PYTORCH = "pt"
77
+ TENSORFLOW = "tf"
78
+ NUMPY = "np"
79
+ JAX = "jax"
80
+
81
+
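+ # precollator mimicking a Hugging Face tokenizer's padding interface, with special tokens taken from the Geneformer token dictionary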
82
+ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
83
+ mask_token = "<mask>"
84
+ mask_token_id = token_dictionary.get("<mask>")
85
+ pad_token = "<pad>"
86
+ pad_token_id = token_dictionary.get("<pad>")
87
+ padding_side = "right"
88
+ all_special_ids = [
89
+ token_dictionary.get("<mask>"),
90
+ token_dictionary.get("<pad>")
91
+ ]
92
+ model_input_names = ["input_ids"]
93
+
94
+ def _get_padding_truncation_strategies(
95
+ self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
96
+ ):
97
+ """
98
+ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
99
+ and pad_to_max_length) and behaviors.
100
+ """
101
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
102
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
103
+
104
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
105
+ # If you only set max_length, it activates truncation for max_length
106
+ if max_length is not None and padding is False and truncation is False:
107
+ if verbose:
108
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
109
+ logger.warning(
110
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, "
111
+ "please use `truncation=True` to explicitly truncate examples to max length. "
112
+ "Defaulting to 'longest_first' truncation strategy. "
113
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
114
+ "more precisely by providing a specific strategy to `truncation`."
115
+ )
116
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
117
+ truncation = "longest_first"
118
+
119
+ # Get padding strategy
120
+ if padding is False and old_pad_to_max_length:
121
+ if verbose:
122
+ warnings.warn(
123
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
124
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
125
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
126
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
127
+ "maximal input size of the model (e.g. 512 for Bert).",
128
+ FutureWarning,
129
+ )
130
+ if max_length is None:
131
+ padding_strategy = PaddingStrategy.LONGEST
132
+ else:
133
+ padding_strategy = PaddingStrategy.MAX_LENGTH
134
+ elif padding is not False:
135
+ if padding is True:
136
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
137
+ elif not isinstance(padding, PaddingStrategy):
138
+ padding_strategy = PaddingStrategy(padding)
139
+ elif isinstance(padding, PaddingStrategy):
140
+ padding_strategy = padding
141
+ else:
142
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
143
+
144
+ # Get truncation strategy
145
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
146
+ if verbose:
147
+ warnings.warn(
148
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
149
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
150
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
151
+ "maximal input size of the model (e.g. 512 for Bert). "
152
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
153
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
154
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
155
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
156
+ FutureWarning,
157
+ )
158
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
159
+ elif truncation is not False:
160
+ if truncation is True:
161
+ truncation_strategy = (
162
+ TruncationStrategy.LONGEST_FIRST
163
+ ) # Default to truncate the longest sequences in pairs of inputs
164
+ elif not isinstance(truncation, TruncationStrategy):
165
+ truncation_strategy = TruncationStrategy(truncation)
166
+ elif isinstance(truncation, TruncationStrategy):
167
+ truncation_strategy = truncation
168
+ else:
169
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
170
+
171
+ # Set max length if needed
172
+ if max_length is None:
173
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
174
+ if self.model_max_length > LARGE_INTEGER:
175
+ if verbose:
176
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
177
+ logger.warning(
178
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
179
+ "Default to no padding."
180
+ )
181
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
182
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
183
+ else:
184
+ max_length = self.model_max_length
185
+
186
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
187
+ if self.model_max_length > LARGE_INTEGER:
188
+ if verbose:
189
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
190
+ logger.warning(
191
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
192
+ "Default to no truncation."
193
+ )
194
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
195
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
196
+ else:
197
+ max_length = self.model_max_length
198
+
199
+ # Test if we have a padding token
200
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
201
+ raise ValueError(
202
+ "Asking to pad but the tokenizer does not have a padding token. "
203
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
204
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
205
+ )
206
+
207
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
208
+ if (
209
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
210
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
211
+ and pad_to_multiple_of is not None
212
+ and max_length is not None
213
+ and (max_length % pad_to_multiple_of != 0)
214
+ ):
215
+ raise ValueError(
216
+ f"Truncation and padding are both activated but "
217
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
218
+ )
219
+
220
+ return padding_strategy, truncation_strategy, max_length, kwargs
221
+
222
+ def pad(
223
+ self,
224
+ encoded_inputs: Union[
225
+ BatchEncoding,
226
+ List[BatchEncoding],
227
+ Dict[str, EncodedInput],
228
+ Dict[str, List[EncodedInput]],
229
+ List[Dict[str, EncodedInput]],
230
+ ],
231
+ class_type, # options: "gene" or "cell"
232
+ padding: Union[bool, str, PaddingStrategy] = True,
233
+ max_length: Optional[int] = None,
234
+ pad_to_multiple_of: Optional[int] = None,
235
+ return_attention_mask: Optional[bool] = True,
236
+ return_tensors: Optional[Union[str, TensorType]] = None,
237
+ verbose: bool = True,
238
+ ) -> BatchEncoding:
239
+ """
240
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
241
+ in the batch.
242
+
243
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
244
+ ``self.pad_token_id`` and ``self.pad_token_type_id``)
245
+
246
+ .. note::
247
+
248
+ If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
249
+ result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
250
+ case of PyTorch tensors, you will lose the specific device of your tensors however.
251
+
252
+ Args:
253
+ encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
254
+ Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
255
+ List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
256
+ List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
257
+ well as in a PyTorch Dataloader collate function.
258
+
259
+ Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
260
+ see the note above for the return type.
261
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
262
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
263
+ index) among:
264
+
265
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
266
+ single sequence if provided).
267
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
268
+ maximum acceptable input length for the model if that argument is not provided.
269
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
270
+ different lengths).
271
+ max_length (:obj:`int`, `optional`):
272
+ Maximum length of the returned list and optionally padding length (see above).
273
+ pad_to_multiple_of (:obj:`int`, `optional`):
274
+ If set will pad the sequence to a multiple of the provided value.
275
+
276
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
277
+ >= 7.5 (Volta).
278
+ return_attention_mask (:obj:`bool`, `optional`):
279
+ Whether to return the attention mask. If left to the default, will return the attention mask according
280
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
281
+
282
+ `What are attention masks? <../glossary.html#attention-mask>`__
283
+ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
284
+ If set, will return tensors instead of list of python integers. Acceptable values are:
285
+
286
+ * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
287
+ * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
288
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
289
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
290
+ Whether or not to print more information and warnings.
291
+ """
292
+ # If we have a list of dicts, let's convert it in a dict of lists
293
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
294
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
295
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
296
+
297
+         # The model's main input name, usually `input_ids`, has to be passed for padding
298
+ if self.model_input_names[0] not in encoded_inputs:
299
+ raise ValueError(
300
+                 "You should supply an encoding or a list of encodings to this method "
301
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
302
+ )
303
+
304
+ required_input = encoded_inputs[self.model_input_names[0]]
305
+
306
+ if not required_input:
307
+ if return_attention_mask:
308
+ encoded_inputs["attention_mask"] = []
309
+ return encoded_inputs
310
+
311
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
312
+ # and rebuild them afterwards if no return_tensors is specified
313
+ # Note that we lose the specific device the tensor may be on for PyTorch
314
+
315
+ first_element = required_input[0]
316
+ if isinstance(first_element, (list, tuple)):
317
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
318
+ index = 0
319
+             while index < len(required_input) and len(required_input[index]) == 0:
320
+ index += 1
321
+ if index < len(required_input):
322
+ first_element = required_input[index][0]
323
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
324
+ if not isinstance(first_element, (int, list, tuple)):
325
+ if is_tf_available() and _is_tensorflow(first_element):
326
+ return_tensors = "tf" if return_tensors is None else return_tensors
327
+ elif is_torch_available() and _is_torch(first_element):
328
+ return_tensors = "pt" if return_tensors is None else return_tensors
329
+ elif isinstance(first_element, np.ndarray):
330
+ return_tensors = "np" if return_tensors is None else return_tensors
331
+ else:
332
+ raise ValueError(
333
+ f"type of {first_element} unknown: {type(first_element)}. "
334
+                         f"Should be a python list/tuple of ints, a numpy array, a pytorch tensor or a tensorflow tensor."
335
+ )
336
+
337
+ for key, value in encoded_inputs.items():
338
+ encoded_inputs[key] = to_py_obj(value)
339
+
340
+ # Convert padding_strategy in PaddingStrategy
341
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
342
+ padding=padding, max_length=max_length, verbose=verbose
343
+ )
344
+
345
+ required_input = encoded_inputs[self.model_input_names[0]]
346
+ if required_input and not isinstance(required_input[0], (list, tuple)):
347
+ encoded_inputs = self._pad(
348
+ encoded_inputs,
349
+ class_type=class_type,
350
+ max_length=max_length,
351
+ padding_strategy=padding_strategy,
352
+ pad_to_multiple_of=pad_to_multiple_of,
353
+ return_attention_mask=return_attention_mask,
354
+ )
355
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
356
+
357
+ batch_size = len(required_input)
358
+ assert all(
359
+ len(v) == batch_size for v in encoded_inputs.values()
360
+ ), "Some items in the output dictionary have a different batch size than others."
361
+
362
+ if padding_strategy == PaddingStrategy.LONGEST:
363
+ max_length = max(len(inputs) for inputs in required_input)
364
+ padding_strategy = PaddingStrategy.MAX_LENGTH
365
+
366
+ batch_outputs = {}
367
+ for i in range(batch_size):
368
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
369
+ outputs = self._pad(
370
+ inputs,
371
+ class_type=class_type,
372
+ max_length=max_length,
373
+ padding_strategy=padding_strategy,
374
+ pad_to_multiple_of=pad_to_multiple_of,
375
+ return_attention_mask=return_attention_mask,
376
+ )
377
+
378
+ for key, value in outputs.items():
379
+ if key not in batch_outputs:
380
+ batch_outputs[key] = []
381
+ batch_outputs[key].append(value)
382
+ if class_type == "cell":
383
+ del batch_outputs["label"]
384
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
385
+
386
+ def _pad(
387
+ self,
388
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
389
+ class_type, # options: "gene" or "cell"
390
+ max_length: Optional[int] = None,
391
+ padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
392
+ pad_to_multiple_of: Optional[int] = None,
393
+ return_attention_mask: Optional[bool] = True,
394
+ ) -> dict:
395
+ """
396
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
397
+
398
+ Args:
399
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
400
+ max_length: maximum length of the returned list and optionally padding length (see below).
401
+ Will truncate by taking into account the special tokens.
402
+ padding_strategy: PaddingStrategy to use for padding.
403
+
404
+                 - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch (default)
405
+                 - PaddingStrategy.MAX_LENGTH: Pad to the max length
406
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
407
+ The tokenizer padding sides are defined in self.padding_side:
408
+
409
+ - 'left': pads on the left of the sequences
410
+ - 'right': pads on the right of the sequences
411
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
412
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
413
+ >= 7.5 (Volta).
414
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
415
+ """
416
+ # Load from model defaults
417
+ if return_attention_mask is None:
418
+ return_attention_mask = "attention_mask" in self.model_input_names
419
+
420
+ required_input = encoded_inputs[self.model_input_names[0]]
421
+
422
+ if padding_strategy == PaddingStrategy.LONGEST:
423
+ max_length = len(required_input)
424
+
425
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
426
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
427
+
428
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
429
+
430
+ if needs_to_be_padded:
431
+ difference = max_length - len(required_input)
432
+ if self.padding_side == "right":
433
+ if return_attention_mask:
434
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
435
+ if "token_type_ids" in encoded_inputs:
436
+ encoded_inputs["token_type_ids"] = (
437
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
438
+ )
439
+ if "special_tokens_mask" in encoded_inputs:
440
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
441
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
442
+ if class_type == "gene":
443
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
444
+ elif self.padding_side == "left":
445
+ if return_attention_mask:
446
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
447
+ if "token_type_ids" in encoded_inputs:
448
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
449
+ "token_type_ids"
450
+ ]
451
+ if "special_tokens_mask" in encoded_inputs:
452
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
453
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
454
+ if class_type == "gene":
455
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
456
+ else:
457
+             raise ValueError("Invalid padding side: " + str(self.padding_side))
458
+ elif return_attention_mask and "attention_mask" not in encoded_inputs:
459
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
460
+
461
+ return encoded_inputs
462
+
463
+ def get_special_tokens_mask(
464
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
465
+ ) -> List[int]:
466
+ """
467
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
468
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
469
+ Args:
470
+ token_ids_0 (:obj:`List[int]`):
471
+ List of ids of the first sequence.
472
+ token_ids_1 (:obj:`List[int]`, `optional`):
473
+ List of ids of the second sequence.
474
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
475
+ Whether or not the token list is already formatted with special tokens for the model.
476
+ Returns:
477
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
478
+ """
479
+ assert already_has_special_tokens and token_ids_1 is None, (
480
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
481
+             "Please use a slow (full python) tokenizer to activate this argument. "
482
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
483
+ "to get the special tokens mask in any tokenizer. "
484
+ )
485
+
486
+ all_special_ids = self.all_special_ids # cache the property
487
+
488
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
489
+
490
+ return special_tokens_mask
491
+
492
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
493
+ """
494
+         Converts a token string (or a sequence of tokens) to a single integer id (or a sequence of ids), using the
495
+ vocabulary.
496
+ Args:
497
+ tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
498
+ Returns:
499
+ :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
500
+ """
501
+ if tokens is None:
502
+ return None
503
+
504
+ if isinstance(tokens, str):
505
+ return self._convert_token_to_id_with_added_voc(tokens)
506
+
507
+ ids = []
508
+ for token in tokens:
509
+ ids.append(self._convert_token_to_id_with_added_voc(token))
510
+ return ids
511
+
512
+ def _convert_token_to_id_with_added_voc(self, token):
513
+ if token is None:
514
+ return None
515
+
516
+ return token_dictionary.get(token)
517
+
518
+ def __len__(self):
519
+ return len(token_dictionary)
520
+
521
+
522
+ # collator functions
523
+
524
+ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
525
+ """
526
+ Data collator that will dynamically pad the inputs received, as well as the labels.
527
+ Args:
528
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
529
+ The tokenizer used for encoding the data.
530
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
531
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
532
+ among:
533
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
534
+             sequence is provided).
535
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
536
+ maximum acceptable input length for the model if that argument is not provided.
537
+             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
538
+ different lengths).
539
+ max_length (:obj:`int`, `optional`):
540
+ Maximum length of the returned list and optionally padding length (see above).
541
+ pad_to_multiple_of (:obj:`int`, `optional`):
542
+ If set will pad the sequence to a multiple of the provided value.
543
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
544
+ 7.5 (Volta).
545
+ label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
546
+             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
547
+ """
548
+
549
+ tokenizer = PrecollatorForGeneAndCellClassification()
550
+ class_type = "gene"
551
+ padding: Union[bool, str, PaddingStrategy] = True
552
+ max_length: Optional[int] = None
553
+ pad_to_multiple_of: Optional[int] = None
554
+ label_pad_token_id: int = -100
555
+
556
+ def __init__(self, *args, **kwargs) -> None:
557
+ super().__init__(
558
+ tokenizer=self.tokenizer,
559
+ padding=self.padding,
560
+ max_length=self.max_length,
561
+ pad_to_multiple_of=self.pad_to_multiple_of,
562
+ label_pad_token_id=self.label_pad_token_id,
563
+ *args, **kwargs)
564
+
565
+ def _prepare_batch(self, features):
566
+ label_name = "label" if "label" in features[0].keys() else "labels"
567
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
568
+ batch = self.tokenizer.pad(
569
+ features,
570
+ class_type=self.class_type,
571
+ padding=self.padding,
572
+ max_length=self.max_length,
573
+ pad_to_multiple_of=self.pad_to_multiple_of,
574
+ return_tensors="pt",
575
+ )
576
+ return batch
577
+
578
+ def __call__(self, features):
579
+ batch = self._prepare_batch(features)
580
+
581
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
582
+ return batch
583
+
584
+
585
+ class DataCollatorForCellClassification(DataCollatorForGeneClassification):
586
+
587
+ class_type = "cell"
588
+
589
+ def _prepare_batch(self, features):
590
+
591
+ batch = super()._prepare_batch(features)
592
+
593
+ # Special handling for labels.
594
+ # Ensure that tensor is created with the correct type
595
+ # (it should be automatically the case, but let's make sure of it.)
596
+ first = features[0]
597
+ if "label" in first and first["label"] is not None:
598
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
599
+ dtype = torch.long if isinstance(label, int) else torch.float
600
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
601
+
602
+ return batch
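To make the collators above concrete, here is a minimal sketch of wiring DataCollatorForCellClassification into a PyTorch DataLoader. The import path, the pad token behavior, and the toy feature dicts (mirroring a tokenized .dataset row with input_ids, length and label columns) are assumptions for illustration, not part of this commit.

from torch.utils.data import DataLoader
from geneformer import DataCollatorForCellClassification  # import path is an assumption

# toy examples mirroring tokenized .dataset rows: rank value encoding, its length, and a cell-level label
features = [
    {"input_ids": [5, 12, 7], "length": 3, "label": 0},
    {"input_ids": [9, 4], "length": 2, "label": 1},
]

collator = DataCollatorForCellClassification()
loader = DataLoader(features, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
# batch["input_ids"] is right-padded to the longest sequence in the batch;
# batch["attention_mask"] marks real tokens; batch["labels"] holds the cell labels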
geneformer/emb_extractor.py ADDED
@@ -0,0 +1,493 @@
1
+ """
2
+ Geneformer embedding extractor.
3
+
4
+ Usage:
5
+ from geneformer import EmbExtractor
6
+ embex = EmbExtractor(model_type="CellClassifier",
7
+ num_classes=3,
8
+ emb_mode="cell",
9
+ cell_emb_style="mean_pool",
10
+ filter_data={"cell_type":["cardiomyocyte"]},
11
+ max_ncells=1000,
12
+ max_ncells_to_plot=1000,
13
+ emb_layer=-1,
14
+ emb_label=["disease","cell_type"],
15
+ labels_to_plot=["disease","cell_type"],
16
+ forward_batch_size=100,
17
+ nproc=16,
18
+ summary_stat=None)
19
+ embs = embex.extract_embs("path/to/model",
20
+ "path/to/input_data",
21
+ "path/to/output_directory",
22
+ "output_prefix")
23
+ embex.plot_embs(embs=embs,
24
+ plot_style="heatmap",
25
+ output_directory="path/to/output_directory",
26
+ output_prefix="output_prefix")
27
+
28
+ """
29
+
30
+ # imports
31
+ import logging
32
+ import anndata
33
+ import matplotlib.pyplot as plt
34
+ import numpy as np
35
+ import pandas as pd
36
+ import pickle
37
+ from tdigest import TDigest
38
+ import scanpy as sc
39
+ import seaborn as sns
40
+ import torch
41
+ from collections import Counter
42
+ from pathlib import Path
43
+ from tqdm.notebook import trange
44
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
45
+
46
+ from .tokenizer import TOKEN_DICTIONARY_FILE
47
+
48
+ from .in_silico_perturber import downsample_and_sort, \
49
+ gen_attention_mask, \
50
+ get_model_input_size, \
51
+ load_and_filter, \
52
+ load_model, \
53
+ mean_nonpadding_embs, \
54
+ pad_tensor_list, \
55
+ quant_layers
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+ # extract embeddings
60
+ def get_embs(model,
61
+ filtered_input_data,
62
+ emb_mode,
63
+ layer_to_quant,
64
+ pad_token_id,
65
+ forward_batch_size,
66
+ summary_stat):
67
+
68
+ model_input_size = get_model_input_size(model)
69
+ total_batch_length = len(filtered_input_data)
70
+
71
+ if summary_stat is None:
72
+ embs_list = []
73
+ elif summary_stat is not None:
74
+ # test embedding extraction for example cell and extract # emb dims
75
+ example = filtered_input_data.select([i for i in range(1)])
76
+ example.set_format(type="torch")
77
+ emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
78
+ # initiate tdigests for # of emb dims
79
+ embs_tdigests = [TDigest() for _ in range(emb_dims)]
80
+
81
+ for i in trange(0, total_batch_length, forward_batch_size):
82
+ max_range = min(i+forward_batch_size, total_batch_length)
83
+
84
+ minibatch = filtered_input_data.select([i for i in range(i, max_range)])
85
+ max_len = max(minibatch["length"])
86
+ original_lens = torch.tensor(minibatch["length"]).to("cuda")
87
+ minibatch.set_format(type="torch")
88
+
89
+ input_data_minibatch = minibatch["input_ids"]
90
+ input_data_minibatch = pad_tensor_list(input_data_minibatch,
91
+ max_len,
92
+ pad_token_id,
93
+ model_input_size)
94
+
95
+ with torch.no_grad():
96
+ outputs = model(
97
+ input_ids = input_data_minibatch.to("cuda"),
98
+ attention_mask = gen_attention_mask(minibatch)
99
+ )
100
+
101
+ embs_i = outputs.hidden_states[layer_to_quant]
102
+
103
+ if emb_mode == "cell":
104
+ mean_embs = mean_nonpadding_embs(embs_i, original_lens)
105
+ if summary_stat is None:
106
+ embs_list += [mean_embs]
107
+ elif summary_stat is not None:
108
+ # update tdigests with current batch for each emb dim
109
+ # note: tdigest batch update known to be slow so updating serially
110
+ [embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
111
+
112
+ del outputs
113
+ del minibatch
114
+ del input_data_minibatch
115
+ del embs_i
116
+ del mean_embs
117
+ torch.cuda.empty_cache()
118
+
119
+ if summary_stat is None:
120
+ embs_stack = torch.cat(embs_list)
121
+ # calculate summary stat embs from approximated tdigests
122
+ elif summary_stat is not None:
123
+ if summary_stat == "mean":
124
+ summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
125
+ elif summary_stat == "median":
126
+ summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
127
+ embs_stack = torch.tensor(summary_emb_list)
128
+
129
+ return embs_stack
130
+
131
+ def test_emb(model, example, layer_to_quant):
132
+ with torch.no_grad():
133
+ outputs = model(
134
+ input_ids = example.to("cuda")
135
+ )
136
+
137
+ embs_test = outputs.hidden_states[layer_to_quant]
138
+ return embs_test.size()[2]
139
+
140
+ def label_embs(embs, downsampled_data, emb_labels):
141
+ embs_df = pd.DataFrame(embs.cpu())
142
+ if emb_labels is not None:
143
+ for label in emb_labels:
144
+ emb_label = downsampled_data[label]
145
+ embs_df[label] = emb_label
146
+ return embs_df
147
+
148
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
149
+ only_embs_df = embs_df.iloc[:,:emb_dims]
150
+ only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
151
+ only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
152
+ vars_dict = {"embs": only_embs_df.columns}
153
+ obs_dict = {"cell_id": list(only_embs_df.index),
154
+ f"{label}": list(embs_df[label])}
155
+ adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
156
+ sc.tl.pca(adata, svd_solver='arpack')
157
+ sc.pp.neighbors(adata)
158
+ sc.tl.umap(adata)
159
+ sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
160
+ sns.set_style("white")
161
+ default_kwargs_dict = {"palette":"Set2", "size":200}
162
+ if kwargs_dict is not None:
163
+ default_kwargs_dict.update(kwargs_dict)
164
+
165
+ sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
166
+
167
+ def gen_heatmap_class_colors(labels, df):
168
+ pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
169
+ lut = dict(zip(map(str, Counter(labels).keys()), pal))
170
+ colors = pd.Series(labels, index=df.index).map(lut)
171
+ return colors
172
+
173
+ def gen_heatmap_class_dict(classes, label_colors_series):
174
+ class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
175
+ class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
176
+ return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"]))
177
+
178
+ def make_colorbar(embs_df, label):
179
+
180
+ labels = list(embs_df[label])
181
+
182
+ cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
183
+ label_colors = pd.DataFrame(cell_type_colors, columns=[label])
184
+
185
+ for i,row in label_colors.iterrows():
186
+ colors=row[0]
187
+ if len(colors)!=3 or any(np.isnan(colors)):
188
+ print(i,colors)
189
+
190
+ label_colors.isna().sum()
191
+
192
+ # create dictionary for colors and classes
193
+ label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
194
+ return label_colors, label_color_dict
195
+
196
+ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
197
+ sns.set_style("white")
198
+ sns.set(font_scale=2)
199
+ plt.figure(figsize=(15, 15), dpi=150)
200
+ label_colors, label_color_dict = make_colorbar(embs_df, label)
201
+
202
+ default_kwargs_dict = {"row_cluster": True,
203
+ "col_cluster": True,
204
+ "row_colors": label_colors,
205
+ "standard_scale": 1,
206
+ "linewidths": 0,
207
+ "xticklabels": False,
208
+ "yticklabels": False,
209
+ "figsize": (15,15),
210
+ "center": 0,
211
+ "cmap": "magma"}
212
+
213
+ if kwargs_dict is not None:
214
+ default_kwargs_dict.update(kwargs_dict)
215
+ g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)
216
+
217
+ plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
218
+
219
+ for label_color in list(label_color_dict.keys()):
220
+ g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0)
221
+
222
+ l1 = g.ax_col_dendrogram.legend(title=f"{label}",
223
+ loc="lower center",
224
+ ncol=4,
225
+ bbox_to_anchor=(0.5, 1),
226
+ facecolor="white")
227
+
228
+ plt.savefig(output_file, bbox_inches='tight')
229
+
230
+ class EmbExtractor:
231
+ valid_option_dict = {
232
+ "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
233
+ "num_classes": {int},
234
+ "emb_mode": {"cell","gene"},
235
+ "cell_emb_style": {"mean_pool"},
236
+ "filter_data": {None, dict},
237
+ "max_ncells": {None, int},
238
+ "emb_layer": {-1, 0},
239
+ "emb_label": {None, list},
240
+ "labels_to_plot": {None, list},
241
+ "forward_batch_size": {int},
242
+ "nproc": {int},
243
+ "summary_stat": {None, "mean", "median"},
244
+ }
245
+ def __init__(
246
+ self,
247
+ model_type="Pretrained",
248
+ num_classes=0,
249
+ emb_mode="cell",
250
+ cell_emb_style="mean_pool",
251
+ filter_data=None,
252
+ max_ncells=1000,
253
+ emb_layer=-1,
254
+ emb_label=None,
255
+ labels_to_plot=None,
256
+ forward_batch_size=100,
257
+ nproc=4,
258
+ summary_stat=None,
259
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
260
+ ):
261
+ """
262
+ Initialize embedding extractor.
263
+
264
+ Parameters
265
+ ----------
266
+ model_type : {"Pretrained","GeneClassifier","CellClassifier"}
267
+ Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
268
+ num_classes : int
269
+ If model is a gene or cell classifier, specify number of classes it was trained to classify.
270
+ For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
271
+ emb_mode : {"cell","gene"}
272
+ Whether to output cell or gene embeddings.
273
+ cell_emb_style : "mean_pool"
274
+ Method for summarizing cell embeddings.
275
+ Currently only option is mean pooling of gene embeddings for given cell.
276
+ filter_data : None, dict
277
+ Default is to extract embeddings from all input data.
278
+ Otherwise, dictionary specifying .dataset column name and list of values to filter by.
279
+ max_ncells : None, int
280
+ Maximum number of cells to extract embeddings from.
281
+ Default is 1000 cells randomly sampled from input data.
282
+ If None, will extract embeddings from all cells.
283
+ emb_layer : {-1, 0}
284
+ Embedding layer to extract.
285
+ The last layer is most specifically weighted to optimize the given learning objective.
286
+ Generally, it is best to extract the 2nd to last layer to get a more general representation.
287
+ -1: 2nd to last layer
288
+ 0: last layer
289
+ emb_label : None, list
290
+ List of column name(s) in .dataset to add as labels to embedding output.
291
+ labels_to_plot : None, list
292
+ Cell labels to plot.
293
+ Shown as color bar in heatmap.
294
+ Shown as cell color in umap.
295
+ Plotting umap requires labels to plot.
296
+ forward_batch_size : int
297
+ Batch size for forward pass.
298
+ nproc : int
299
+ Number of CPU processes to use.
300
+ summary_stat : {None, "mean", "median"}
301
+ If not None, outputs only approximated mean or median embedding of input data.
302
+ Recommended if encountering memory constraints while generating goal embedding positions.
303
+ Slower but more memory-efficient.
304
+ token_dictionary_file : Path
305
+ Path to pickle file containing token dictionary (Ensembl ID:token).
306
+ """
307
+
308
+ self.model_type = model_type
309
+ self.num_classes = num_classes
310
+ self.emb_mode = emb_mode
311
+ self.cell_emb_style = cell_emb_style
312
+ self.filter_data = filter_data
313
+ self.max_ncells = max_ncells
314
+ self.emb_layer = emb_layer
315
+ self.emb_label = emb_label
316
+ self.labels_to_plot = labels_to_plot
317
+ self.forward_batch_size = forward_batch_size
318
+ self.nproc = nproc
319
+ self.summary_stat = summary_stat
320
+
321
+ self.validate_options()
322
+
323
+ # load token dictionary (Ensembl IDs:token)
324
+ with open(token_dictionary_file, "rb") as f:
325
+ self.gene_token_dict = pickle.load(f)
326
+
327
+ self.pad_token_id = self.gene_token_dict.get("<pad>")
328
+
329
+
330
+ def validate_options(self):
331
+ # first disallow options under development
332
+ if self.emb_mode == "gene":
333
+ logger.error(
334
+ "Extraction and plotting of gene-level embeddings currently under development. " \
335
+ "Current valid option for 'emb_mode': 'cell'"
336
+ )
337
+ raise
338
+
339
+ # confirm arguments are within valid options and compatible with each other
340
+ for attr_name,valid_options in self.valid_option_dict.items():
341
+ attr_value = self.__dict__[attr_name]
342
+ if type(attr_value) not in {list, dict}:
343
+ if attr_value in valid_options:
344
+ continue
345
+ valid_type = False
346
+ for option in valid_options:
347
+ if (option in [int,list,dict]) and isinstance(attr_value, option):
348
+ valid_type = True
349
+ break
350
+ if valid_type:
351
+ continue
352
+ logger.error(
353
+ f"Invalid option for {attr_name}. " \
354
+ f"Valid options for {attr_name}: {valid_options}"
355
+ )
356
+ raise
357
+
358
+ if self.filter_data is not None:
359
+ for key,value in self.filter_data.items():
360
+ if type(value) != list:
361
+ self.filter_data[key] = [value]
362
+ logger.warning(
363
+ "Values in filter_data dict must be lists. " \
364
+ f"Changing {key} value to list ([{value}]).")
365
+
366
+ def extract_embs(self,
367
+ model_directory,
368
+ input_data_file,
369
+ output_directory,
370
+ output_prefix):
371
+ """
372
+ Extract embeddings from input data and save as results in output_directory.
373
+
374
+ Parameters
375
+ ----------
376
+ model_directory : Path
377
+ Path to directory containing model
378
+ input_data_file : Path
379
+ Path to directory containing .dataset inputs
380
+ output_directory : Path
381
+ Path to directory where embedding data will be saved as csv
382
+ output_prefix : str
383
+ Prefix for output file
384
+ """
385
+
386
+ filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
387
+ downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
388
+ model = load_model(self.model_type, self.num_classes, model_directory)
389
+ layer_to_quant = quant_layers(model)+self.emb_layer
390
+ embs = get_embs(model,
391
+ downsampled_data,
392
+ self.emb_mode,
393
+ layer_to_quant,
394
+ self.pad_token_id,
395
+ self.forward_batch_size,
396
+ self.summary_stat)
397
+
398
+ if self.summary_stat is None:
399
+ embs_df = label_embs(embs, downsampled_data, self.emb_label)
400
+ elif self.summary_stat is not None:
401
+ embs_df = pd.DataFrame(embs.cpu()).T
402
+
403
+ # save embeddings to output_path
404
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
405
+ embs_df.to_csv(output_path)
406
+
407
+ return embs_df
408
+
409
+ def plot_embs(self,
410
+ embs,
411
+ plot_style,
412
+ output_directory,
413
+ output_prefix,
414
+ max_ncells_to_plot=1000,
415
+ kwargs_dict=None):
416
+
417
+ """
418
+ Plot embeddings, coloring by provided labels.
419
+
420
+ Parameters
421
+ ----------
422
+ embs : pandas.core.frame.DataFrame
423
+ Pandas dataframe containing embeddings output from extract_embs
424
+ plot_style : str
425
+ Style of plot: "heatmap" or "umap"
426
+ output_directory : Path
427
+ Path to directory where plots will be saved as pdf
428
+ output_prefix : str
429
+ Prefix for output file
430
+ max_ncells_to_plot : None, int
431
+ Maximum number of cells to plot.
432
+ Default is 1000 cells randomly sampled from embeddings.
433
+ If None, will plot embeddings from all cells.
434
+ kwargs_dict : dict
435
+ Dictionary of kwargs to pass to plotting function.
436
+ """
437
+
438
+ if plot_style not in ["heatmap","umap"]:
439
+ logger.error(
440
+ "Invalid option for 'plot_style'. " \
441
+ "Valid options: {'heatmap','umap'}"
442
+ )
443
+ raise
444
+
445
+ if (plot_style == "umap") and (self.labels_to_plot is None):
446
+ logger.error(
447
+ "Plotting UMAP requires 'labels_to_plot'. "
448
+ )
449
+ raise
450
+
451
+         if None not in (max_ncells_to_plot, self.max_ncells) and max_ncells_to_plot > self.max_ncells:
452
+ max_ncells_to_plot = self.max_ncells
453
+ logger.warning(
454
+ "max_ncells_to_plot must be <= max_ncells. " \
455
+ f"Changing max_ncells_to_plot to {self.max_ncells}.")
456
+
457
+ if (max_ncells_to_plot is not None) \
458
+             and (max_ncells_to_plot < embs.shape[0]):
459
+ embs = embs.sample(max_ncells_to_plot, axis=0)
460
+
461
+ if self.emb_label is None:
462
+ label_len = 0
463
+ else:
464
+ label_len = len(self.emb_label)
465
+
466
+ emb_dims = embs.shape[1] - label_len
467
+
468
+ if self.emb_label is None:
469
+ emb_labels = None
470
+ else:
471
+ emb_labels = embs.columns[emb_dims:]
472
+
473
+ if plot_style == "umap":
474
+ for label in self.labels_to_plot:
475
+ if label not in emb_labels:
476
+ logger.warning(
477
+ f"Label {label} from labels_to_plot " \
478
+ f"not present in provided embeddings dataframe.")
479
+ continue
480
+ output_prefix_label = "_" + output_prefix + f"_umap_{label}"
481
+ output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
482
+ plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
483
+
484
+ if plot_style == "heatmap":
485
+ for label in self.labels_to_plot:
486
+ if label not in emb_labels:
487
+ logger.warning(
488
+ f"Label {label} from labels_to_plot " \
489
+ f"not present in provided embeddings dataframe.")
490
+ continue
491
+ output_prefix_label = output_prefix + f"_heatmap_{label}"
492
+ output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
493
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
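As a reference for what get_embs does in "cell" mode, the following is a minimal, self-contained sketch of mean-pooling gene embeddings into cell embeddings while excluding padded positions; the tensor shapes are made up for illustration and are not taken from this commit.

import torch

hidden = torch.randn(2, 5, 256)    # (cells, genes incl. padding, embedding dim) from one hidden layer
lengths = torch.tensor([5, 3])     # number of non-padded genes per cell
mask = torch.arange(hidden.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
cell_embs = (hidden * mask.unsqueeze(2).float()).sum(1) / lengths.unsqueeze(1).float()
print(cell_embs.shape)             # torch.Size([2, 256])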
geneformer/gene_median_dictionary.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3129017daec18ff275f0900e674957d9b6547af266ef0e2c97b03d20b5d4c225
3
+ size 1640760
geneformer/gene_name_id_dict.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90f7100adec84828555873be1bae866e83509ce016dedfd9633d12e01dee4ea4
3
+ size 607393
geneformer/in_silico_perturber.py ADDED
@@ -0,0 +1,1297 @@
1
+ """
2
+ Geneformer in silico perturber.
3
+
4
+ Usage:
5
+ from geneformer import InSilicoPerturber
6
+ isp = InSilicoPerturber(perturb_type="delete",
7
+ perturb_rank_shift=None,
8
+ genes_to_perturb="all",
9
+ combos=0,
10
+ anchor_gene=None,
11
+ model_type="Pretrained",
12
+ num_classes=0,
13
+ emb_mode="cell",
14
+ cell_emb_style="mean_pool",
15
+ filter_data={"cell_type":["cardiomyocyte"]},
16
+ cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
+ max_ncells=None,
18
+ emb_layer=-1,
19
+ forward_batch_size=100,
20
+ nproc=4)
21
+ isp.perturb_data("path/to/model",
22
+ "path/to/input_data",
23
+ "path/to/output_directory",
24
+ "output_prefix")
25
+ """
26
+
27
+ # imports
28
+ import itertools as it
29
+ import logging
30
+ import numpy as np
31
+ import pickle
32
+ import re
33
+ import seaborn as sns; sns.set()
34
+ import torch
35
+ from collections import defaultdict
36
+ from datasets import Dataset, load_from_disk
37
+ from tqdm.notebook import trange
38
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
39
+
40
+ from .tokenizer import TOKEN_DICTIONARY_FILE
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ # load data and filter by defined criteria
46
+ def load_and_filter(filter_data, nproc, input_data_file):
47
+ data = load_from_disk(input_data_file)
48
+ if filter_data is not None:
49
+ for key,value in filter_data.items():
50
+ def filter_data_by_criteria(example):
51
+ return example[key] in value
52
+ data = data.filter(filter_data_by_criteria, num_proc=nproc)
53
+ if len(data) == 0:
54
+ logger.error(
55
+ "No cells remain after filtering. Check filtering criteria.")
56
+ raise
57
+ data_shuffled = data.shuffle(seed=42)
58
+ return data_shuffled
59
+
60
+ # load model to GPU
61
+ def load_model(model_type, num_classes, model_directory):
62
+ if model_type == "Pretrained":
63
+ model = BertForMaskedLM.from_pretrained(model_directory,
64
+ output_hidden_states=True,
65
+ output_attentions=False)
66
+ elif model_type == "GeneClassifier":
67
+ model = BertForTokenClassification.from_pretrained(model_directory,
68
+ num_labels=num_classes,
69
+ output_hidden_states=True,
70
+ output_attentions=False)
71
+ elif model_type == "CellClassifier":
72
+ model = BertForSequenceClassification.from_pretrained(model_directory,
73
+ num_labels=num_classes,
74
+ output_hidden_states=True,
75
+ output_attentions=False)
76
+ # put the model in eval mode for fwd pass
77
+ model.eval()
78
+ model = model.to("cuda:0")
79
+ return model
80
+
81
+ def quant_layers(model):
82
+ layer_nums = []
83
+ for name, parameter in model.named_parameters():
84
+ if "layer" in name:
85
+ layer_nums += [int(name.split("layer.")[1].split(".")[0])]
86
+ return int(max(layer_nums))+1
87
+
88
+ def get_model_input_size(model):
89
+     return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
90
+
91
+ def flatten_list(megalist):
92
+ return [item for sublist in megalist for item in sublist]
93
+
94
+ def measure_length(example):
95
+ example["length"] = len(example["input_ids"])
96
+ return example
97
+
98
+ def downsample_and_sort(data_shuffled, max_ncells):
99
+ num_cells = len(data_shuffled)
100
+ # if max number of cells is defined, then subsample to this max number
101
+ if max_ncells != None:
102
+ num_cells = min(max_ncells,num_cells)
103
+ data_subset = data_shuffled.select([i for i in range(num_cells)])
104
+ # sort dataset with largest cell first to encounter any memory errors earlier
105
+ data_sorted = data_subset.sort("length",reverse=True)
106
+ return data_sorted
107
+
108
+ def get_possible_states(cell_states_to_model):
109
+ possible_states = []
110
+ for key in ["start_state","goal_state"]:
111
+ possible_states += [cell_states_to_model[key]]
112
+ possible_states += cell_states_to_model.get("alt_states",[])
113
+ return possible_states
114
+
115
+ def forward_pass_single_cell(model, example_cell, layer_to_quant):
116
+ example_cell.set_format(type="torch")
117
+ input_data = example_cell["input_ids"]
118
+ with torch.no_grad():
119
+ outputs = model(
120
+ input_ids = input_data.to("cuda")
121
+ )
122
+ emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
123
+ del outputs
124
+ return emb
125
+
126
+ def perturb_emb_by_index(emb, indices):
127
+ mask = torch.ones(emb.numel(), dtype=torch.bool)
128
+ mask[indices] = False
129
+ return emb[mask]
130
+
131
+ def delete_indices(example):
132
+ indices = example["perturb_index"]
133
+ if any(isinstance(el, list) for el in indices):
134
+ indices = flatten_list(indices)
135
+ for index in sorted(indices, reverse=True):
136
+ del example["input_ids"][index]
137
+ return example
138
+
139
+ # for genes_to_perturb = "all" where only genes within cell are overexpressed
140
+ def overexpress_indices(example):
141
+ indices = example["perturb_index"]
142
+ if any(isinstance(el, list) for el in indices):
143
+ indices = flatten_list(indices)
144
+ for index in sorted(indices, reverse=True):
145
+ example["input_ids"].insert(0, example["input_ids"].pop(index))
146
+ return example
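For intuition, here is a small walk-through (with hypothetical token values) of how the delete and overexpress maps above act on a single rank value encoding: deletion removes the perturbed gene, while overexpression moves it to the front, i.e. to the highest rank.

example = {"input_ids": [10, 42, 7, 99], "perturb_index": [[2]]}
flat_indices = [i for sub in example["perturb_index"] for i in sub]

# "delete": drop the perturbed gene from the rank value encoding
deleted = [tok for i, tok in enumerate(example["input_ids"]) if i not in flat_indices]
# deleted == [10, 42, 99]

# "overexpress": move the perturbed gene to the highest rank (front of the encoding)
overexpressed = list(example["input_ids"])
for index in sorted(flat_indices, reverse=True):
    overexpressed.insert(0, overexpressed.pop(index))
# overexpressed == [7, 10, 42, 99]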
147
+
148
+ # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
149
+ def overexpress_tokens(example):
150
+ # -100 indicates tokens to overexpress are not present in rank value encoding
151
+ if example["perturb_index"] != [-100]:
152
+ example = delete_indices(example)
153
+ [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
154
+ return example
155
+
156
+ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
157
+ # indices_to_remove is list of indices to remove
158
+ indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
159
+ num_dims = emb.dim()
160
+ emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
161
+ sliced_emb = emb[emb_slice]
162
+ return sliced_emb
163
+
164
+ def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
165
+ output_batch = torch.stack([
166
+ remove_indices_from_emb(emb_batch[i, :, :], idx, gene_dim-1) for
167
+ i, idx in enumerate(list_of_indices_to_remove)
168
+ ])
169
+ return output_batch
170
+
171
+ def make_perturbation_batch(example_cell,
172
+ perturb_type,
173
+ tokens_to_perturb,
174
+ anchor_token,
175
+ combo_lvl,
176
+ num_proc):
177
+ if tokens_to_perturb == "all":
178
+ if perturb_type in ["overexpress","activate"]:
179
+ range_start = 1
180
+ elif perturb_type in ["delete","inhibit"]:
181
+ range_start = 0
182
+ indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
183
+ elif combo_lvl>0 and (anchor_token is not None):
184
+         example_input_ids = example_cell["input_ids"][0]
185
+ anchor_index = example_input_ids.index(anchor_token[0])
186
+ indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
187
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
188
+ else:
189
+ example_input_ids = example_cell["input_ids"][0]
190
+ indices_to_perturb = [[example_input_ids.index(token)] if token in example_input_ids else None for token in tokens_to_perturb]
191
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
192
+
193
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
194
+ if combo_lvl>0 and (anchor_token is None):
195
+ if tokens_to_perturb != "all":
196
+ if len(tokens_to_perturb) == combo_lvl+1:
197
+ indices_to_perturb = [list(x) for x in it.combinations(indices_to_perturb, combo_lvl+1)]
198
+ else:
199
+ all_indices = [[i] for i in range(example_cell["length"][0])]
200
+ all_indices = [index for index in all_indices if index not in indices_to_perturb]
201
+ indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
202
+ length = len(indices_to_perturb)
203
+ perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length,
204
+ "perturb_index": indices_to_perturb})
205
+ if length<400:
206
+ num_proc_i = 1
207
+ else:
208
+ num_proc_i = num_proc
209
+ if perturb_type == "delete":
210
+ perturbation_dataset = perturbation_dataset.map(delete_indices, num_proc=num_proc_i)
211
+ elif perturb_type == "overexpress":
212
+ perturbation_dataset = perturbation_dataset.map(overexpress_indices, num_proc=num_proc_i)
213
+ return perturbation_dataset, indices_to_perturb
214
+
215
+ # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
216
+ # so that only non-perturbed gene embeddings are compared to each other
217
+ # in original or perturbed context
218
+ def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
219
+ all_embs_list = []
220
+
221
+ # if making comparison batch for multiple perturbations in single cell
222
+ if perturb_group == False:
223
+ original_emb_list = [original_emb_batch]*len(indices_to_perturb)
224
+ # if making comparison batch for single perturbation in multiple cells
225
+ elif perturb_group == True:
226
+ original_emb_list = original_emb_batch
227
+
228
+
229
+ for i in range(len(original_emb_list)):
230
+ original_emb = original_emb_list[i]
231
+ indices = indices_to_perturb[i]
232
+ if indices == [-100]:
233
+ all_embs_list += [original_emb[:]]
234
+ continue
235
+ emb_list = []
236
+ start = 0
237
+ if any(isinstance(el, list) for el in indices):
238
+ indices = flatten_list(indices)
239
+ for i in sorted(indices):
240
+ emb_list += [original_emb[start:i]]
241
+ start = i+1
242
+ emb_list += [original_emb[start:]]
243
+ all_embs_list += [torch.cat(emb_list)]
244
+ len_set = set([emb.size()[0] for emb in all_embs_list])
245
+ if len(len_set) > 1:
246
+ max_len = max(len_set)
247
+ all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
248
+ return torch.stack(all_embs_list)
249
+
250
+ # average embedding position of goal cell states
251
+ def get_cell_state_avg_embs(model,
252
+ filtered_input_data,
253
+ cell_states_to_model,
254
+ layer_to_quant,
255
+ pad_token_id,
256
+ forward_batch_size,
257
+ num_proc):
258
+
259
+ model_input_size = get_model_input_size(model)
260
+ possible_states = get_possible_states(cell_states_to_model)
261
+ state_embs_dict = dict()
262
+ for possible_state in possible_states:
263
+ state_embs_list = []
264
+ original_lens = []
265
+
266
+ def filter_states(example):
267
+ state_key = cell_states_to_model["state_key"]
268
+ return example[state_key] in [possible_state]
269
+ filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
270
+ total_batch_length = len(filtered_input_data_state)
271
+ if ((total_batch_length-1)/forward_batch_size).is_integer():
272
+ forward_batch_size = forward_batch_size-1
273
+ max_len = max(filtered_input_data_state["length"])
274
+ for i in range(0, total_batch_length, forward_batch_size):
275
+ max_range = min(i+forward_batch_size, total_batch_length)
276
+
277
+ state_minibatch = filtered_input_data_state.select([i for i in range(i, max_range)])
278
+ state_minibatch.set_format(type="torch")
279
+
280
+ input_data_minibatch = state_minibatch["input_ids"]
281
+ original_lens += state_minibatch["length"]
282
+ input_data_minibatch = pad_tensor_list(input_data_minibatch,
283
+ max_len,
284
+ pad_token_id,
285
+ model_input_size)
286
+ attention_mask = gen_attention_mask(state_minibatch, max_len)
287
+
288
+ with torch.no_grad():
289
+ outputs = model(
290
+ input_ids = input_data_minibatch.to("cuda"),
291
+ attention_mask = attention_mask
292
+ )
293
+
294
+ state_embs_i = outputs.hidden_states[layer_to_quant]
295
+ state_embs_list += [state_embs_i]
296
+ del outputs
297
+ del state_minibatch
298
+ del input_data_minibatch
299
+ del attention_mask
300
+ del state_embs_i
301
+ torch.cuda.empty_cache()
302
+
303
+ state_embs = torch.cat(state_embs_list)
304
+ avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
305
+ avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
306
+ state_embs_dict[possible_state] = avg_state_emb
307
+ return state_embs_dict
308
+
309
+ # quantify cosine similarity of perturbed vs original or alternate states
310
+ def quant_cos_sims(model,
311
+ perturb_type,
312
+ perturbation_batch,
313
+ forward_batch_size,
314
+ layer_to_quant,
315
+ original_emb,
316
+ tokens_to_perturb,
317
+ indices_to_perturb,
318
+ perturb_group,
319
+ cell_states_to_model,
320
+ state_embs_dict,
321
+ pad_token_id,
322
+ model_input_size,
323
+ nproc):
324
+ cos = torch.nn.CosineSimilarity(dim=2)
325
+ total_batch_length = len(perturbation_batch)
326
+ if ((total_batch_length-1)/forward_batch_size).is_integer():
327
+ forward_batch_size = forward_batch_size-1
328
+ if cell_states_to_model is None:
329
+ if perturb_group == False: # (if perturb_group is True, original_emb is filtered_input_data)
330
+ comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
331
+ cos_sims = []
332
+ else:
333
+ possible_states = get_possible_states(cell_states_to_model)
334
+ cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
335
+
336
+ # measure length of each element in perturbation_batch
337
+ perturbation_batch = perturbation_batch.map(
338
+ measure_length, num_proc=nproc
339
+ )
340
+
341
+ for i in range(0, total_batch_length, forward_batch_size):
342
+ max_range = min(i+forward_batch_size, total_batch_length)
343
+
344
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
345
+ # determine if need to pad or truncate batch
346
+ minibatch_length_set = set(perturbation_minibatch["length"])
347
+ minibatch_lengths = perturbation_minibatch["length"]
348
+ if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
349
+ needs_pad_or_trunc = True
350
+ else:
351
+ needs_pad_or_trunc = False
352
+ max_len = max(minibatch_length_set)
353
+
354
+ if needs_pad_or_trunc == True:
355
+ max_len = min(max(minibatch_length_set),model_input_size)
356
+ def pad_or_trunc_example(example):
357
+ example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
358
+ pad_token_id,
359
+ max_len)
360
+ return example
361
+ perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
362
+
363
+ perturbation_minibatch.set_format(type="torch")
364
+
365
+ input_data_minibatch = perturbation_minibatch["input_ids"]
366
+ attention_mask = gen_attention_mask(perturbation_minibatch, max_len)
367
+
368
+ # extract embeddings for perturbation minibatch
369
+ with torch.no_grad():
370
+ outputs = model(
371
+ input_ids = input_data_minibatch.to("cuda"),
372
+ attention_mask = attention_mask
373
+ )
374
+ del input_data_minibatch
375
+ del perturbation_minibatch
376
+ del attention_mask
377
+
378
+ if len(indices_to_perturb)>1:
379
+ minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
380
+ else:
381
+ minibatch_emb = outputs.hidden_states[layer_to_quant]
382
+
383
+ if perturb_type == "overexpress":
384
+ # remove overexpressed genes to quantify effect on remaining genes
385
+ if perturb_group == False:
386
+ overexpressed_to_remove = 1
387
+ if perturb_group == True:
388
+ overexpressed_to_remove = len(tokens_to_perturb)
389
+ minibatch_emb = minibatch_emb[:,overexpressed_to_remove:,:]
390
+
391
+ # if quantifying single perturbation in multiple different cells, pad original batch and extract embs
392
+ if perturb_group == True:
393
+ # pad minibatch of original batch to extract embeddings
394
+ # truncate to the (model input size - # tokens to overexpress) to ensure comparability
395
+ # since max input size of perturb batch will be reduced by # tokens to overexpress
396
+ original_minibatch = original_emb.select([i for i in range(i, max_range)])
397
+ original_minibatch_lengths = original_minibatch["length"]
398
+ original_minibatch_length_set = set(original_minibatch["length"])
399
+
400
+ indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
401
+
402
+ if perturb_type == "overexpress":
403
+ new_max_len = model_input_size - len(tokens_to_perturb)
404
+ else:
405
+ new_max_len = model_input_size
406
+ if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
407
+ new_max_len = min(max(original_minibatch_length_set),new_max_len)
408
+ def pad_or_trunc_example(example):
409
+ example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
410
+ return example
411
+ original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
412
+ original_minibatch.set_format(type="torch")
413
+ original_input_data_minibatch = original_minibatch["input_ids"]
414
+ attention_mask = gen_attention_mask(original_minibatch, new_max_len)
415
+ # extract embeddings for original minibatch
416
+ with torch.no_grad():
417
+ original_outputs = model(
418
+ input_ids = original_input_data_minibatch.to("cuda"),
419
+ attention_mask = attention_mask
420
+ )
421
+ del original_input_data_minibatch
422
+ del original_minibatch
423
+ del attention_mask
424
+
425
+ if len(indices_to_perturb)>1:
426
+ original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
427
+ else:
428
+ original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
429
+
430
+ # embedding dimension of the genes
431
+ gene_dim = 1
432
+ # exclude overexpression due to case when genes are not expressed but being overexpressed
433
+ if perturb_type != "overexpress":
434
+ original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
435
+ indices_to_perturb_minibatch,
436
+ gene_dim)
437
+
438
+ # cosine similarity between original emb and batch items
439
+ if cell_states_to_model is None:
440
+ if perturb_group == False:
441
+ minibatch_comparison = comparison_batch[i:max_range]
442
+ elif perturb_group == True:
443
+ minibatch_comparison = original_minibatch_emb
444
+
445
+ cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
446
+ elif cell_states_to_model is not None:
447
+ for state in possible_states:
448
+ if perturb_group == False:
449
+ cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
450
+ minibatch_emb,
451
+ state_embs_dict[state],
452
+ perturb_group)
453
+ elif perturb_group == True:
454
+ cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
455
+ minibatch_emb,
456
+ state_embs_dict[state],
457
+ perturb_group,
458
+ torch.tensor(original_minibatch_lengths, device="cuda"),
459
+ torch.tensor(minibatch_lengths, device="cuda"))
460
+ del outputs
461
+ del minibatch_emb
462
+ if cell_states_to_model is None:
463
+ del minibatch_comparison
464
+ torch.cuda.empty_cache()
465
+ if cell_states_to_model is None:
466
+ cos_sims_stack = torch.cat(cos_sims)
467
+ return cos_sims_stack
468
+ else:
469
+ for state in possible_states:
470
+ cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
471
+ return cos_sims_vs_alt_dict
472
+
473
+ # calculate cos sim shift of perturbation with respect to origin and alternative cell
474
+ def cos_sim_shift(original_emb,
475
+ minibatch_emb,
476
+ end_emb,
477
+ perturb_group,
478
+ original_minibatch_lengths = None,
479
+ minibatch_lengths = None):
480
+ cos = torch.nn.CosineSimilarity(dim=2)
481
+ if not perturb_group:
482
+ original_emb = torch.mean(original_emb,dim=0,keepdim=True)
483
+ original_emb = original_emb[None, :]
484
+ origin_v_end = torch.squeeze(cos(original_emb, end_emb)) #test
485
+ else:
486
+ if original_emb.size() != minibatch_emb.size():
487
+ logger.error(
488
+ f"Embeddings are not the same dimensions. " \
489
+ f"original_emb is {original_emb.size()}. " \
490
+ f"minibatch_emb is {minibatch_emb.size()}. "
491
+ )
492
+ raise
493
+
494
+ if original_minibatch_lengths is not None:
495
+ original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
496
+ # else:
497
+ # original_emb = torch.mean(original_emb,dim=1,keepdim=True)
498
+
499
+ end_emb = torch.unsqueeze(end_emb, 1)
500
+ origin_v_end = cos(original_emb, end_emb)
501
+ origin_v_end = torch.squeeze(origin_v_end)
502
+ if minibatch_lengths is not None:
503
+ perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
504
+ else:
505
+ perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
506
+
507
+ perturb_v_end = cos(perturb_emb, end_emb)
508
+ perturb_v_end = torch.squeeze(perturb_v_end)
509
+ return [(perturb_v_end-origin_v_end).to("cpu")]
510
+
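The quantity returned by cos_sim_shift is the change in cosine similarity to the goal state: positive values mean the perturbation moved the cell embedding toward that state. A minimal numeric sketch with random embeddings (the shapes are assumptions for illustration):

import torch

cos = torch.nn.CosineSimilarity(dim=-1)
original = torch.randn(4, 256)      # mean-pooled cell embeddings before perturbation
perturbed = torch.randn(4, 256)     # cell embeddings after the in silico perturbation
goal_state = torch.randn(1, 256)    # average embedding of the goal cell state

shift = cos(perturbed, goal_state) - cos(original, goal_state)
# shift[i] > 0: the perturbation shifted cell i toward the goal state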
511
+ def pad_list(input_ids, pad_token_id, max_len):
512
+ input_ids = np.pad(input_ids,
513
+ (0, max_len-len(input_ids)),
514
+ mode='constant', constant_values=pad_token_id)
515
+ return input_ids
516
+
517
+ def pad_tensor(tensor, pad_token_id, max_len):
518
+ tensor = torch.nn.functional.pad(tensor, pad=(0,
519
+ max_len - tensor.numel()),
520
+ mode='constant',
521
+ value=pad_token_id)
522
+ return tensor
523
+
524
+ def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
525
+ if dim == 0:
526
+ pad = (0, 0, 0, max_len - tensor.size()[dim])
527
+ elif dim == 1:
528
+ pad = (0, max_len - tensor.size()[dim], 0, 0)
529
+ tensor = torch.nn.functional.pad(tensor, pad=pad,
530
+ mode='constant',
531
+ value=pad_token_id)
532
+ return tensor
533
+
534
+ def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
535
+ if isinstance(encoding, torch.Tensor):
536
+ encoding_len = encoding.size()[0]
537
+ elif isinstance(encoding, list):
538
+ encoding_len = len(encoding)
539
+ if encoding_len > max_len:
540
+ encoding = encoding[0:max_len]
541
+ elif encoding_len < max_len:
542
+ if isinstance(encoding, torch.Tensor):
543
+ encoding = pad_tensor(encoding, pad_token_id, max_len)
544
+ elif isinstance(encoding, list):
545
+ encoding = pad_list(encoding, pad_token_id, max_len)
546
+ return encoding
547
+
548
+ # pad list of tensors and convert to tensor
549
+ def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_size):
550
+
551
+ # Determine maximum tensor length
552
+ if dynamic_or_constant == "dynamic":
553
+ max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
554
+ elif type(dynamic_or_constant) == int:
555
+ max_len = dynamic_or_constant
556
+ else:
557
+ max_len = model_input_size
558
+ logger.warning(
559
+ "If padding style is constant, must provide integer value. " \
560
+ f"Setting padding to max input size {model_input_size}.")
561
+
562
+ # pad all tensors to maximum length
563
+ tensor_list = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
564
+
565
+ # return stacked tensors
566
+ return torch.stack(tensor_list)
567
+
568
+ def gen_attention_mask(minibatch_encoding, max_len = None):
569
+ if max_len is None:
570
+ max_len = max(minibatch_encoding["length"])
571
+ original_lens = minibatch_encoding["length"]
572
+ attention_mask = [[1]*original_len
573
+ +[0]*(max_len - original_len)
574
+ if original_len <= max_len
575
+ else [1]*max_len
576
+ for original_len in original_lens]
577
+ return torch.tensor(attention_mask).to("cuda")
578
+
579
+ # get cell embeddings excluding padding
580
+ def mean_nonpadding_embs(embs, original_lens):
581
+ # mask based on padding lengths
582
+ mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
583
+
584
+ # extend mask dimensions to match the embeddings tensor
585
+ mask = mask.unsqueeze(2).expand_as(embs)
586
+
587
+ # use the mask to zero out the embeddings in padded areas
588
+ masked_embs = embs * mask.float()
589
+
590
+ # sum and divide by the lengths to get the mean of non-padding embs
591
+ mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
592
+ return mean_embs
593
+
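For intuition, a minimal CPU-only sketch of the non-padding mean pooling performed by mean_nonpadding_embs above; the toy shapes and lengths are placeholders, and broadcasting stands in for the expand_as call.

import torch

# toy batch: 2 cells padded to 4 positions, embedding dimension 3
embs = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)
original_lens = torch.tensor([2, 4])  # cell 0 has 2 real tokens, cell 1 has 4

# mask out padded positions, then average only over the real tokens
mask = torch.arange(embs.size(1)).unsqueeze(0) < original_lens.unsqueeze(1)
masked_embs = embs * mask.unsqueeze(2).float()
mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
print(mean_embs)  # row 0 averages positions 0-1 only; row 1 averages all 4 positions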
594
+ class InSilicoPerturber:
595
+ valid_option_dict = {
596
+ "perturb_type": {"delete","overexpress","inhibit","activate"},
597
+ "perturb_rank_shift": {None, 1, 2, 3},
598
+ "genes_to_perturb": {"all", list},
599
+ "combos": {0, 1},
600
+ "anchor_gene": {None, str},
601
+ "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
602
+ "num_classes": {int},
603
+ "emb_mode": {"cell","cell_and_gene"},
604
+ "cell_emb_style": {"mean_pool"},
605
+ "filter_data": {None, dict},
606
+ "cell_states_to_model": {None, dict},
607
+ "max_ncells": {None, int},
608
+ "cell_inds_to_perturb": {"all", dict},
609
+ "emb_layer": {-1, 0},
610
+ "forward_batch_size": {int},
611
+ "nproc": {int},
612
+ }
613
+ def __init__(
614
+ self,
615
+ perturb_type="delete",
616
+ perturb_rank_shift=None,
617
+ genes_to_perturb="all",
618
+ combos=0,
619
+ anchor_gene=None,
620
+ model_type="Pretrained",
621
+ num_classes=0,
622
+ emb_mode="cell",
623
+ cell_emb_style="mean_pool",
624
+ filter_data=None,
625
+ cell_states_to_model=None,
626
+ max_ncells=None,
627
+ cell_inds_to_perturb="all",
628
+ emb_layer=-1,
629
+ forward_batch_size=100,
630
+ nproc=4,
631
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
632
+ ):
633
+ """
634
+ Initialize in silico perturber.
635
+
636
+ Parameters
637
+ ----------
638
+ perturb_type : {"delete","overexpress","inhibit","activate"}
639
+ Type of perturbation.
640
+ "delete": delete gene from rank value encoding
641
+ "overexpress": move gene to front of rank value encoding
642
+ "inhibit": move gene to lower quartile of rank value encoding
643
+ "activate": move gene to higher quartile of rank value encoding
644
+ perturb_rank_shift : None, {1,2,3}
645
+ Number of quartiles by which to shift rank of gene.
646
+ For example, if perturb_type="activate" and perturb_rank_shift=1:
647
+ genes in 4th quartile will move to middle of 3rd quartile.
648
+ genes in 3rd quartile will move to middle of 2nd quartile.
649
+ genes in 2nd quartile will move to middle of 1st quartile.
650
+ genes in 1st quartile will move to front of rank value encoding.
651
+ For example, if perturb_type="inhibit" and perturb_rank_shift=2:
652
+ genes in 1st quartile will move to middle of 3rd quartile.
653
+ genes in 2nd quartile will move to middle of 4th quartile.
654
+ genes in 3rd or 4th quartile will move to bottom of rank value encoding.
655
+ genes_to_perturb : "all", list
656
+ Default is perturbing each gene detected in each cell in the dataset.
657
+ Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
658
+ If gene list is provided, then perturber will only test perturbing them all together
659
+ (rather than testing each possible combination of the provided genes).
660
+ combos : {0,1}
661
+ Whether to perturb genes individually (0) or in pairs (1).
662
+ anchor_gene : None, str
663
+ ENSEMBL ID of gene to use as anchor in combination perturbations.
664
+ For example, if combos=1 and anchor_gene="ENSG00000148400":
665
+ anchor gene will be perturbed in combination with each other gene.
666
+ model_type : {"Pretrained","GeneClassifier","CellClassifier"}
667
+ Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
668
+ num_classes : int
669
+ If model is a gene or cell classifier, specify number of classes it was trained to classify.
670
+ For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
671
+ emb_mode : {"cell","cell_and_gene"}
672
+ Whether to output impact of perturbation on cell and/or gene embeddings.
673
+ cell_emb_style : "mean_pool"
674
+ Method for summarizing cell embeddings.
675
+ Currently only option is mean pooling of gene embeddings for given cell.
676
+ filter_data : None, dict
677
+ Default is to use all input data for in silico perturbation study.
678
+ Otherwise, dictionary specifying .dataset column name and list of values to filter by.
679
+ cell_states_to_model: None, dict
680
+ Cell states to model if testing perturbations that achieve goal state change.
681
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
682
+ state_key: key specifying name of column in .dataset that defines the start/goal states
683
+ start_state: value in the state_key column that specifies the start state
684
+ goal_state: value in the state_key column that specifies the goal end state
685
+ alt_states: list of values in the state_key column that specify the alternate end states
686
+ For example: {"state_key": "disease",
687
+ "start_state": "dcm",
688
+ "goal_state": "nf",
689
+ "alt_states": ["hcm", "other1", "other2"]}
690
+ max_ncells : None, int
691
+ Maximum number of cells to test.
692
+ If None, will test all cells.
693
+ cell_inds_to_perturb : "all", dict
694
+ Default is perturbing each cell in the dataset.
695
+ Otherwise, may provide a dict of indices of cells to perturb with keys 'start' and 'end'.
696
+ start: the first index to perturb.
697
+ end: the last index to perturb (exclusive).
698
+ Indices will be selected *after* the filter_data criteria and sorting.
699
+ Useful for splitting extremely large datasets across separate GPUs.
700
+ emb_layer : {-1, 0}
701
+ Embedding layer to use for quantification.
702
+ -1: 2nd to last layer (recommended for pretrained Geneformer)
703
+ 0: last layer (recommended for cell classifier fine-tuned for disease state)
704
+ forward_batch_size : int
705
+ Batch size for forward pass.
706
+ nproc : int
707
+ Number of CPU processes to use.
708
+ token_dictionary_file : Path
709
+ Path to pickle file containing token dictionary (Ensembl ID:token).
710
+ """
711
+
712
+ self.perturb_type = perturb_type
713
+ self.perturb_rank_shift = perturb_rank_shift
714
+ self.genes_to_perturb = genes_to_perturb
715
+ self.combos = combos
716
+ self.anchor_gene = anchor_gene
717
+ if self.genes_to_perturb == "all":
718
+ self.perturb_group = False
719
+ else:
720
+ self.perturb_group = True
721
+ if (self.anchor_gene is not None) or (self.combos != 0):
722
+ self.anchor_gene = None
723
+ self.combos = 0
724
+ logger.warning(
725
+ "anchor_gene set to None and combos set to 0. " \
726
+ "If providing list of genes to perturb, " \
727
+ "list of genes_to_perturb will be perturbed together, "\
728
+ "without anchor gene or combinations.")
729
+ self.model_type = model_type
730
+ self.num_classes = num_classes
731
+ self.emb_mode = emb_mode
732
+ self.cell_emb_style = cell_emb_style
733
+ self.filter_data = filter_data
734
+ self.cell_states_to_model = cell_states_to_model
735
+ self.max_ncells = max_ncells
736
+ self.cell_inds_to_perturb = cell_inds_to_perturb
737
+ self.emb_layer = emb_layer
738
+ self.forward_batch_size = forward_batch_size
739
+ self.nproc = nproc
740
+
741
+ self.validate_options()
742
+
743
+ # load token dictionary (Ensembl IDs:token)
744
+ with open(token_dictionary_file, "rb") as f:
745
+ self.gene_token_dict = pickle.load(f)
746
+
747
+ self.pad_token_id = self.gene_token_dict.get("<pad>")
748
+
749
+ if self.anchor_gene is None:
750
+ self.anchor_token = None
751
+ else:
752
+ try:
753
+ self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
754
+ except KeyError:
755
+ logger.error(
756
+ f"Anchor gene {self.anchor_gene} not in token dictionary."
757
+ )
758
+ raise
759
+
760
+ if self.genes_to_perturb == "all":
761
+ self.tokens_to_perturb = "all"
762
+ else:
763
+ missing_genes = [gene for gene in self.genes_to_perturb if gene not in self.gene_token_dict.keys()]
764
+ if len(missing_genes) == len(self.genes_to_perturb):
765
+ logger.error(
766
+ "None of the provided genes to perturb are in token dictionary."
767
+ )
768
+ raise
769
+ elif len(missing_genes)>0:
770
+ logger.warning(
771
+ f"Genes to perturb {missing_genes} are not in token dictionary.")
772
+ self.tokens_to_perturb = [self.gene_token_dict.get(gene) for gene in self.genes_to_perturb]
773
+
774
+ def validate_options(self):
775
+ # first disallow options under development
776
+ if self.perturb_type in ["inhibit", "activate"]:
777
+ logger.error(
778
+ "In silico inhibition and activation currently under development. " \
779
+ "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
780
+ )
781
+ raise
782
+
783
+ # confirm arguments are within valid options and compatible with each other
784
+ for attr_name,valid_options in self.valid_option_dict.items():
785
+ attr_value = self.__dict__[attr_name]
786
+ if type(attr_value) not in {list, dict}:
787
+ if attr_value in valid_options:
788
+ continue
789
+ if attr_name in ["anchor_gene"]:
790
+ if type(attr_value) in {str}:
791
+ continue
792
+ valid_type = False
793
+ for option in valid_options:
794
+ if (option in [int,list,dict]) and isinstance(attr_value, option):
795
+ valid_type = True
796
+ break
797
+ if valid_type:
798
+ continue
799
+ logger.error(
800
+ f"Invalid option for {attr_name}. " \
801
+ f"Valid options for {attr_name}: {valid_options}"
802
+ )
803
+ raise
804
+
805
+ if self.perturb_type in ["delete","overexpress"]:
806
+ if self.perturb_rank_shift is not None:
807
+ if self.perturb_type == "delete":
808
+ logger.warning(
809
+ "perturb_rank_shift set to None. " \
810
+ "If perturb type is delete then gene is deleted entirely " \
811
+ "rather than shifted by quartile")
812
+ elif self.perturb_type == "overexpress":
813
+ logger.warning(
814
+ "perturb_rank_shift set to None. " \
815
+ "If perturb type is overexpress then gene is moved to front " \
816
+ "of rank value encoding rather than shifted by quartile")
817
+ self.perturb_rank_shift = None
818
+
819
+ if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
820
+ self.emb_mode = "cell"
821
+ logger.warning(
822
+ "emb_mode set to 'cell'. " \
823
+ "Currently, analysis with anchor gene " \
824
+ "only outputs effect on cell embeddings.")
825
+
826
+ if self.cell_states_to_model is not None:
827
+ if len(self.cell_states_to_model.items()) == 1:
828
+ logger.warning(
829
+ "The single value dictionary for cell_states_to_model will be " \
830
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
831
+ "Please specify state_key, start_state, goal_state, and alt_states " \
832
+ "in the cell_states_to_model dictionary for future use. " \
833
+ "For example, cell_states_to_model={" \
834
+ "'state_key': 'disease', " \
835
+ "'start_state': 'dcm', " \
836
+ "'goal_state': 'nf', " \
837
+ "'alt_states': ['hcm', 'other1', 'other2']}"
838
+ )
839
+ for key,value in self.cell_states_to_model.items():
840
+ if (len(value) == 3) and isinstance(value, tuple):
841
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
842
+ if len(value[0]) == 1 and len(value[1]) == 1:
843
+ all_values = value[0]+value[1]+value[2]
844
+ if len(all_values) == len(set(all_values)):
845
+ continue
846
+ # reformat to the new named key format
847
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
848
+ self.cell_states_to_model = {
849
+ "state_key": list(self.cell_states_to_model.keys())[0],
850
+ "start_state": state_values[0][0],
851
+ "goal_state": state_values[1][0],
852
+ "alt_states": state_values[2:][0]
853
+ }
854
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
855
+ if (self.cell_states_to_model["state_key"] is None) \
856
+ or (self.cell_states_to_model["start_state"] is None) \
857
+ or (self.cell_states_to_model["goal_state"] is None):
858
+ logger.error(
859
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
860
+ raise
861
+
862
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
863
+ logger.error(
864
+ "All states must be unique.")
865
+ raise
866
+
867
+ if self.cell_states_to_model["alt_states"] is not None:
868
+ if type(self.cell_states_to_model["alt_states"]) is not list:
869
+ logger.error(
870
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
871
+ )
872
+ raise
873
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
874
+ logger.error(
875
+ "All states must be unique.")
876
+ raise
877
+
878
+ else:
879
+ logger.error(
880
+ "cell_states_to_model must only have the following four keys: " \
881
+ "'state_key', 'start_state', 'goal_state', 'alt_states'. " \
882
+ "For example, cell_states_to_model={" \
883
+ "'state_key': 'disease', " \
884
+ "'start_state': 'dcm', " \
885
+ "'goal_state': 'nf', " \
886
+ "'alt_states': ['hcm', 'other1', 'other2']}"
887
+ )
888
+ raise
889
+
890
+ if self.anchor_gene is not None:
891
+ self.anchor_gene = None
892
+ logger.warning(
893
+ "anchor_gene set to None. " \
894
+ "Currently, anchor gene not available " \
895
+ "when modeling multiple cell states.")
896
+
897
+ if self.perturb_type in ["inhibit","activate"]:
898
+ if self.perturb_rank_shift is None:
899
+ logger.error(
900
+ "If perturb_type is inhibit or activate then " \
901
+ "quartile to shift by must be specified.")
902
+ raise
903
+
904
+ if self.filter_data is not None:
905
+ for key,value in self.filter_data.items():
906
+ if type(value) != list:
907
+ self.filter_data[key] = [value]
908
+ logger.warning(
909
+ "Values in filter_data dict must be lists. " \
910
+ f"Changing {key} value to list ([{value}]).")
911
+
912
+ if self.cell_inds_to_perturb != "all":
913
+ if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
914
+ logger.error(
915
+ "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
916
+ )
917
+ raise
918
+ if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
919
+ logger.error(
920
+ 'cell_inds_to_perturb values must be non-negative.'
921
+ )
922
+ raise
923
+
924
+ def perturb_data(self,
925
+ model_directory,
926
+ input_data_file,
927
+ output_directory,
928
+ output_prefix):
929
+ """
930
+ Perturb genes in input data and save as results in output_directory.
931
+
932
+ Parameters
933
+ ----------
934
+ model_directory : Path
935
+ Path to directory containing model
936
+ input_data_file : Path
937
+ Path to directory containing .dataset inputs
938
+ output_directory : Path
939
+ Path to directory where perturbation data will be saved as batched pickle files
940
+ output_prefix : str
941
+ Prefix for output files
942
+ """
943
+
944
+ filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
945
+ model = load_model(self.model_type, self.num_classes, model_directory)
946
+ layer_to_quant = quant_layers(model)+self.emb_layer
947
+
948
+ if self.cell_states_to_model is None:
949
+ state_embs_dict = None
950
+ else:
951
+ # confirm that all states are valid to prevent futile filtering
952
+ state_name = self.cell_states_to_model["state_key"]
953
+ state_values = filtered_input_data[state_name]
954
+ for value in get_possible_states(self.cell_states_to_model):
955
+ if value not in state_values:
956
+ logger.error(
957
+ f"{value} is not present in the dataset's {state_name} attribute.")
958
+ raise
959
+ # get dictionary of average cell state embeddings for comparison
960
+ downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
961
+ state_embs_dict = get_cell_state_avg_embs(model,
962
+ downsampled_data,
963
+ self.cell_states_to_model,
964
+ layer_to_quant,
965
+ self.pad_token_id,
966
+ self.forward_batch_size,
967
+ self.nproc)
968
+ # filter for start state cells
969
+ start_state = self.cell_states_to_model["start_state"]
970
+ def filter_for_origin(example):
971
+ return example[state_name] in [start_state]
972
+
973
+ filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
974
+
975
+ self.in_silico_perturb(model,
976
+ filtered_input_data,
977
+ layer_to_quant,
978
+ state_embs_dict,
979
+ output_directory,
980
+ output_prefix)
981
+
982
+ # determine effect of perturbation on other genes
983
+ def in_silico_perturb(self,
984
+ model,
985
+ filtered_input_data,
986
+ layer_to_quant,
987
+ state_embs_dict,
988
+ output_directory,
989
+ output_prefix):
990
+
991
+ output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
992
+ model_input_size = get_model_input_size(model)
993
+
994
+ # filter dataset for cells that have tokens to be perturbed
995
+ if self.anchor_token is not None:
996
+ def if_has_tokens_to_perturb(example):
997
+ return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
998
+ filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
999
+ if len(filtered_input_data) == 0:
1000
+ logger.error(
1001
+ "No cells in dataset contain anchor gene.")
1002
+ raise
1003
+ else:
1004
+ logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
1005
+
1006
+ if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
1007
+ # minimum # genes needed for perturbation test
1008
+ min_genes = len(self.tokens_to_perturb)
1009
+
1010
+ def if_has_tokens_to_perturb(example):
1011
+ return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>=min_genes)
1012
+ filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
1013
+ if len(filtered_input_data) == 0:
1014
+ logger.error(
1015
+ "No cells in dataset contain all genes to perturb as a group.")
1016
+ raise
1017
+
1018
+ cos_sims_dict = defaultdict(list)
1019
+ pickle_batch = -1
1020
+ filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
1021
+ if self.cell_inds_to_perturb != "all":
1022
+ if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
1023
+ logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
1024
+ raise
1025
+ if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
1026
+ logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
1027
+ Setting to the end of the filtered dataset.")
1028
+ self.cell_inds_to_perturb["end"] = len(filtered_input_data)
1029
+ filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
1030
+
1031
+ # make perturbation batch w/ single perturbation in multiple cells
1032
+ if self.perturb_group == True:
1033
+
1034
+ def make_group_perturbation_batch(example):
1035
+ example_input_ids = example["input_ids"]
1036
+ example["tokens_to_perturb"] = self.tokens_to_perturb
1037
+ indices_to_perturb = [example_input_ids.index(token) if token in example_input_ids else None for token in self.tokens_to_perturb]
1038
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
1039
+ if len(indices_to_perturb) > 0:
1040
+ example["perturb_index"] = indices_to_perturb
1041
+ else:
1042
+ # -100 indicates tokens to overexpress are not present in rank value encoding
1043
+ example["perturb_index"] = [-100]
1044
+ if self.perturb_type == "delete":
1045
+ example = delete_indices(example)
1046
+ elif self.perturb_type == "overexpress":
1047
+ example = overexpress_tokens(example)
1048
+ return example
1049
+
1050
+ perturbation_batch = filtered_input_data.map(make_group_perturbation_batch, num_proc=self.nproc)
1051
+ indices_to_perturb = perturbation_batch["perturb_index"]
1052
+
1053
+ cos_sims_data = quant_cos_sims(model,
1054
+ self.perturb_type,
1055
+ perturbation_batch,
1056
+ self.forward_batch_size,
1057
+ layer_to_quant,
1058
+ filtered_input_data,
1059
+ self.tokens_to_perturb,
1060
+ indices_to_perturb,
1061
+ self.perturb_group,
1062
+ self.cell_states_to_model,
1063
+ state_embs_dict,
1064
+ self.pad_token_id,
1065
+ model_input_size,
1066
+ self.nproc)
1067
+
1068
+ perturbed_genes = tuple(self.tokens_to_perturb)
1069
+ original_lengths = filtered_input_data["length"]
1070
+ if self.cell_states_to_model is None:
1071
+ # update cos sims dict
1072
+ # key is tuple of (perturbed_gene, affected_gene)
1073
+ # or (perturbed_genes, "cell_emb") for avg cell emb change
1074
+ cos_sims_data = cos_sims_data.to("cuda")
1075
+ max_padded_len = cos_sims_data.shape[1]
1076
+ for j in range(cos_sims_data.shape[0]):
1077
+ # remove padding before mean pooling cell embedding
1078
+ original_length = original_lengths[j]
1079
+ gene_list = filtered_input_data[j]["input_ids"]
1080
+ indices_removed = indices_to_perturb[j]
1081
+ padding_to_remove = max_padded_len - (original_length \
1082
+ - len(self.tokens_to_perturb) \
1083
+ - len(indices_removed))
1084
+ nonpadding_cos_sims_data = cos_sims_data[j][:-padding_to_remove]
1085
+ cell_cos_sim = torch.mean(nonpadding_cos_sims_data).item()
1086
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += [cell_cos_sim]
1087
+
1088
+ if self.emb_mode == "cell_and_gene":
1089
+ for k in range(cos_sims_data.shape[1]):
1090
+ cos_sim_value = nonpadding_cos_sims_data[k]
1091
+ affected_gene = gene_list[k].item()
1092
+ cos_sims_dict[(perturbed_genes, affected_gene)] += [cos_sim_value.item()]
1093
+ else:
1094
+ # update cos sims dict
1095
+ # key is tuple of (perturbed_genes, "cell_emb")
1096
+ # value is list of tuples of cos sims for cell_states_to_model
1097
+ origin_state_key = self.cell_states_to_model["start_state"]
1098
+ cos_sims_origin = cos_sims_data[origin_state_key]
1099
+ for j in range(cos_sims_origin.shape[0]):
1100
+ data_list = []
1101
+ for data in list(cos_sims_data.values()):
1102
+ data_item = data.to("cuda")
1103
+ data_list += [data_item[j].item()]
1104
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1105
+
1106
+ with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
1107
+ pickle.dump(cos_sims_dict, fp)
1108
+
1109
+ # make perturbation batch w/ multiple perturbations in single cell
1110
+ if self.perturb_group == False:
1111
+
1112
+ for i in trange(len(filtered_input_data)):
1113
+ example_cell = filtered_input_data.select([i])
1114
+ original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
1115
+ gene_list = torch.squeeze(example_cell["input_ids"])
1116
+
1117
+ # re-select the example to restore its original format, since forward_pass_single_cell converts it to torch format in place
1118
+ example_cell = filtered_input_data.select([i])
1119
+
1120
+ if self.anchor_token is None:
1121
+ for combo_lvl in range(self.combos+1):
1122
+ perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
1123
+ self.perturb_type,
1124
+ self.tokens_to_perturb,
1125
+ self.anchor_token,
1126
+ combo_lvl,
1127
+ self.nproc)
1128
+ cos_sims_data = quant_cos_sims(model,
1129
+ self.perturb_type,
1130
+ perturbation_batch,
1131
+ self.forward_batch_size,
1132
+ layer_to_quant,
1133
+ original_emb,
1134
+ self.tokens_to_perturb,
1135
+ indices_to_perturb,
1136
+ self.perturb_group,
1137
+ self.cell_states_to_model,
1138
+ state_embs_dict,
1139
+ self.pad_token_id,
1140
+ model_input_size,
1141
+ self.nproc)
1142
+
1143
+ if self.cell_states_to_model is None:
1144
+ # update cos sims dict
1145
+ # key is tuple of (perturbed_gene, affected_gene)
1146
+ # or (perturbed_gene, "cell_emb") for avg cell emb change
1147
+ cos_sims_data = cos_sims_data.to("cuda")
1148
+ for j in range(cos_sims_data.shape[0]):
1149
+ if self.tokens_to_perturb != "all":
1150
+ j_index = torch.tensor(indices_to_perturb[j])
1151
+ if j_index.shape[0]>1:
1152
+ j_index = torch.squeeze(j_index)
1153
+ else:
1154
+ j_index = torch.tensor([j])
1155
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
1156
+
1157
+ if perturbed_gene.shape[0]==1:
1158
+ perturbed_gene = perturbed_gene.item()
1159
+ elif perturbed_gene.shape[0]>1:
1160
+ perturbed_gene = tuple(perturbed_gene.tolist())
1161
+
1162
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
1163
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
1164
+
1165
+ # not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
1166
+ # gene_list_j = torch.index_select(gene_list, 0, j_index)
1167
+ if self.emb_mode == "cell_and_gene":
1168
+ for k in range(cos_sims_data.shape[1]):
1169
+ cos_sim_value = cos_sims_data[j][k]
1170
+ affected_gene = gene_list[k].item()
1171
+ cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
1172
+ else:
1173
+ # update cos sims dict
1174
+ # key is tuple of (perturbed_gene, "cell_emb")
1175
+ # value is list of tuples of cos sims for cell_states_to_model
1176
+ origin_state_key = self.cell_states_to_model["start_state"]
1177
+ cos_sims_origin = cos_sims_data[origin_state_key]
1178
+
1179
+ for j in range(cos_sims_origin.shape[0]):
1180
+ if (self.tokens_to_perturb != "all") or (combo_lvl>0):
1181
+ j_index = torch.tensor(indices_to_perturb[j])
1182
+ if j_index.shape[0]>1:
1183
+ j_index = torch.squeeze(j_index)
1184
+ else:
1185
+ j_index = torch.tensor([j])
1186
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
1187
+
1188
+ if perturbed_gene.shape[0]==1:
1189
+ perturbed_gene = perturbed_gene.item()
1190
+ elif perturbed_gene.shape[0]>1:
1191
+ perturbed_gene = tuple(perturbed_gene.tolist())
1192
+
1193
+ data_list = []
1194
+ for data in list(cos_sims_data.values()):
1195
+ data_item = data.to("cuda")
1196
+ cell_data = torch.mean(data_item[j]).item()
1197
+ data_list += [cell_data]
1198
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
1199
+
1200
+ elif self.anchor_token is not None:
1201
+ perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
1202
+ self.perturb_type,
1203
+ self.tokens_to_perturb,
1204
+ None, # first run without anchor token to test individual gene perturbations
1205
+ 0,
1206
+ self.nproc)
1207
+ cos_sims_data = quant_cos_sims(model,
1208
+ self.perturb_type,
1209
+ perturbation_batch,
1210
+ self.forward_batch_size,
1211
+ layer_to_quant,
1212
+ original_emb,
1213
+ self.tokens_to_perturb,
1214
+ indices_to_perturb,
1215
+ self.perturb_group,
1216
+ self.cell_states_to_model,
1217
+ state_embs_dict,
1218
+ self.pad_token_id,
1219
+ model_input_size,
1220
+ self.nproc)
1221
+ cos_sims_data = cos_sims_data.to("cuda")
1222
+
1223
+ combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
1224
+ self.perturb_type,
1225
+ self.tokens_to_perturb,
1226
+ self.anchor_token,
1227
+ 1,
1228
+ self.nproc)
1229
+ combo_cos_sims_data = quant_cos_sims(model,
1230
+ self.perturb_type,
1231
+ combo_perturbation_batch,
1232
+ self.forward_batch_size,
1233
+ layer_to_quant,
1234
+ original_emb,
1235
+ self.tokens_to_perturb,
1236
+ combo_indices_to_perturb,
1237
+ self.perturb_group,
1238
+ self.cell_states_to_model,
1239
+ state_embs_dict,
1240
+ self.pad_token_id,
1241
+ model_input_size,
1242
+ self.nproc)
1243
+ combo_cos_sims_data = combo_cos_sims_data.to("cuda")
1244
+
1245
+ # update cos sims dict
1246
+ # key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
1247
+ anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
1248
+ anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
1249
+ non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
1250
+ cos_sims_data = cos_sims_data[non_anchor_indices,:]
1251
+
1252
+ for j in range(cos_sims_data.shape[0]):
1253
+
1254
+ if j<anchor_index:
1255
+ j_index = torch.tensor([j])
1256
+ else:
1257
+ j_index = torch.tensor([j+1])
1258
+
1259
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
1260
+ perturbed_gene = perturbed_gene.item()
1261
+
1262
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
1263
+ combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
1264
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
1265
+ cell_cos_sim, # cos sim deleted gene alone
1266
+ combo_cos_sim)] # cos sim anchor gene + deleted gene
1267
+
1268
+ # save dict to disk every 100 cells
1269
+ if (i/100).is_integer():
1270
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1271
+ pickle.dump(cos_sims_dict, fp)
1272
+ # reset and clear memory every 1000 cells
1273
+ if (i/1000).is_integer():
1274
+ pickle_batch = pickle_batch+1
1275
+ # clear memory
1276
+ del perturbed_gene
1277
+ del cos_sims_data
1278
+ if self.cell_states_to_model is None:
1279
+ del cell_cos_sim
1280
+ if self.cell_states_to_model is not None:
1281
+ del cell_data
1282
+ del data_list
1283
+ elif self.anchor_token is None:
1284
+ if self.emb_mode == "cell_and_gene":
1285
+ del affected_gene
1286
+ del cos_sim_value
1287
+ else:
1288
+ del combo_cos_sim
1289
+ del combo_cos_sims_data
1290
+ # reset dict
1291
+ del cos_sims_dict
1292
+ cos_sims_dict = defaultdict(list)
1293
+ torch.cuda.empty_cache()
1294
+
1295
+ # save remainder cells
1296
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1297
+ pickle.dump(cos_sims_dict, fp)
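To round out this module, a hedged usage sketch of the class defined above, assuming InSilicoPerturber is exported from the geneformer package in the same way InSilicoPerturberStats is in the next file; all paths are placeholders and the keyword values are simply the documented options.

from geneformer import InSilicoPerturber

isp = InSilicoPerturber(perturb_type="delete",       # or "overexpress"
                        genes_to_perturb="all",      # or a list of Ensembl IDs perturbed together as a group
                        model_type="Pretrained",
                        emb_mode="cell",
                        max_ncells=1000,             # illustrative cap on cells to test
                        forward_batch_size=100,
                        nproc=4)

isp.perturb_data("path/to/model",                    # model_directory
                 "path/to/input_data.dataset",       # input_data_file
                 "path/to/output_directory/",        # output_directory
                 "output_prefix")

Note that in_silico_perturb builds its output prefix by direct string concatenation (f"{output_directory}in_silico_..."), so the output directory argument should end with a path separator.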
geneformer/in_silico_perturber_stats.py ADDED
@@ -0,0 +1,716 @@
1
+ """
2
+ Geneformer in silico perturber stats generator.
3
+
4
+ Usage:
5
+ from geneformer import InSilicoPerturberStats
6
+ ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
+ combos=0,
8
+ anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
+ "alt_states": ["hcm", "other1", "other2"]})
13
+ ispstats.get_stats("path/to/input_data",
14
+ None,
15
+ "path/to/output_directory",
16
+ "output_prefix")
17
+ """
18
+
19
+
20
+ import os
21
+ import logging
22
+ import numpy as np
23
+ import pandas as pd
24
+ import pickle
25
+ import random
26
+ import statsmodels.stats.multitest as smt
27
+ from pathlib import Path
28
+ from scipy.stats import ranksums
29
+ from sklearn.mixture import GaussianMixture
30
+ from tqdm.notebook import trange, tqdm
31
+
32
+ from .in_silico_perturber import flatten_list
33
+
34
+ from .tokenizer import TOKEN_DICTIONARY_FILE
35
+
36
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # invert dictionary keys/values
41
+ def invert_dict(dictionary):
42
+ return {v: k for k, v in dictionary.items()}
43
+
44
+ # read raw dictionary files
45
+ def read_dictionaries(input_data_directory, cell_or_gene_emb, anchor_token):
46
+ file_found = 0
47
+ file_path_list = []
48
+ dict_list = []
49
+ for file in os.listdir(input_data_directory):
50
+ # process only _raw.pickle files
51
+ if file.endswith("_raw.pickle"):
52
+ file_found = 1
53
+ file_path_list += [f"{input_data_directory}/{file}"]
54
+ for file_path in tqdm(file_path_list):
55
+ with open(file_path, "rb") as fp:
56
+ cos_sims_dict = pickle.load(fp)
57
+ if cell_or_gene_emb == "cell":
58
+ cell_emb_dict = {k: v for k,
59
+ v in cos_sims_dict.items() if v and "cell_emb" in k}
60
+ dict_list += [cell_emb_dict]
61
+ elif cell_or_gene_emb == "gene":
62
+ gene_emb_dict = {k: v for k,
63
+ v in cos_sims_dict.items() if v and anchor_token == k[0]}
64
+ dict_list += [gene_emb_dict]
65
+ if file_found == 0:
66
+ logger.error(
67
+ "No raw data for processing found within provided directory. " \
68
+ "Please ensure data files end with '_raw.pickle'.")
69
+ raise
70
+ return dict_list
71
+
72
+ # get complete gene list
73
+ def get_gene_list(dict_list,mode):
74
+ if mode == "cell":
75
+ position = 0
76
+ elif mode == "gene":
77
+ position = 1
78
+ gene_set = set()
79
+ for dict_i in dict_list:
80
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
81
+ gene_list = list(gene_set)
82
+ if mode == "gene":
83
+ gene_list.remove("cell_emb")
84
+ gene_list.sort()
85
+ return gene_list
86
+
87
+ def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
88
+ return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
89
+
90
+ def n_detections(token, dict_list, mode, anchor_token):
91
+ cos_sim_megalist = []
92
+ for dict_i in dict_list:
93
+ if mode == "cell":
94
+ cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
95
+ elif mode == "gene":
96
+ cos_sim_megalist += dict_i.get((anchor_token, token),[])
97
+ return len(cos_sim_megalist)
98
+
99
+ def get_fdr(pvalues):
100
+ return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
101
+
102
+ def get_impact_component(test_value, gaussian_mixture_model):
103
+ impact_border = gaussian_mixture_model.means_[0][0]
104
+ nonimpact_border = gaussian_mixture_model.means_[1][0]
105
+ if test_value > nonimpact_border:
106
+ impact_component = 0
107
+ elif test_value < impact_border:
108
+ impact_component = 1
109
+ else:
110
+ impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
111
+ if impact_component_raw == 1:
112
+ impact_component = 0
113
+ elif impact_component_raw == 0:
114
+ impact_component = 1
115
+ return impact_component
116
+
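As a small illustration of the two-component mixture idea behind get_impact_component (and isp_stats_mixture_model below), the following standalone sketch fits a GaussianMixture to synthetic per-gene mean cosine shifts; the cluster locations and sizes are made up for the example.

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# synthetic per-gene mean shifts: a small "impact" cluster and a large "non-impact" cluster
avg_shifts = np.concatenate([rng.normal(0.85, 0.02, 50),
                             rng.normal(0.99, 0.005, 500)])

gm = GaussianMixture(n_components=2, random_state=0).fit(avg_shifts.reshape(-1, 1))
print(gm.means_.ravel())               # the two component means
print(gm.predict([[0.86], [0.995]]))   # most likely component for two test values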
117
+ # aggregate data for single perturbation in multiple cells
118
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
119
+ names=["Cosine_shift"]
120
+ cos_sims_full_df = pd.DataFrame(columns=names)
121
+
122
+ cos_shift_data = []
123
+ token = cos_sims_df["Gene"][0]
124
+ for dict_i in dict_list:
125
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
126
+ cos_sims_full_df["Cosine_shift"] = cos_shift_data
127
+ return cos_sims_full_df
128
+
129
+ # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
130
+ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
131
+ cell_state_key = cell_states_to_model["start_state"]
132
+ if ("alt_states" not in cell_states_to_model.keys()) \
133
+ or (len(cell_states_to_model["alt_states"]) == 0) \
134
+ or (cell_states_to_model["alt_states"] == [None]):
135
+ alt_end_state_exists = False
136
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
137
+ alt_end_state_exists = True
138
+
139
+ # for single perturbation in multiple cells, there are no random perturbations to compare to
140
+ if genes_perturbed != "all":
141
+ names=["Shift_to_goal_end",
142
+ "Shift_to_alt_end"]
143
+ if alt_end_state_exists == False:
144
+ names.remove("Shift_to_alt_end")
145
+ cos_sims_full_df = pd.DataFrame(columns=names)
146
+
147
+ cos_shift_data = []
148
+ token = cos_sims_df["Gene"][0]
149
+ for dict_i in dict_list:
150
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
151
+ if alt_end_state_exists == False:
152
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
153
+ if alt_end_state_exists == True:
154
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
155
+ cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
156
+
157
+ # sort by shift to desired state
158
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
159
+ ascending=[False])
160
+ return cos_sims_full_df
161
+
162
+ elif genes_perturbed == "all":
163
+ random_tuples = []
164
+ for i in trange(cos_sims_df.shape[0]):
165
+ token = cos_sims_df["Gene"][i]
166
+ for dict_i in dict_list:
167
+ random_tuples += dict_i.get((token, "cell_emb"),[])
168
+
169
+ if alt_end_state_exists == False:
170
+ goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
171
+ elif alt_end_state_exists == True:
172
+ goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
173
+ alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
174
+
175
+ # downsample to improve speed of ranksums
176
+ if len(goal_end_random_megalist) > 100_000:
177
+ random.seed(42)
178
+ goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
179
+ if alt_end_state_exists == True:
180
+ if len(alt_end_random_megalist) > 100_000:
181
+ random.seed(42)
182
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
183
+
184
+ names=["Gene",
185
+ "Gene_name",
186
+ "Ensembl_ID",
187
+ "Shift_to_goal_end",
188
+ "Shift_to_alt_end",
189
+ "Goal_end_vs_random_pval",
190
+ "Alt_end_vs_random_pval"]
191
+ if alt_end_state_exists == False:
192
+ names.remove("Shift_to_alt_end")
193
+ names.remove("Alt_end_vs_random_pval")
194
+ cos_sims_full_df = pd.DataFrame(columns=names)
195
+
196
+ for i in trange(cos_sims_df.shape[0]):
197
+ token = cos_sims_df["Gene"][i]
198
+ name = cos_sims_df["Gene_name"][i]
199
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
200
+ cos_shift_data = []
201
+
202
+ for dict_i in dict_list:
203
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
204
+
205
+ if alt_end_state_exists == False:
206
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
207
+ elif alt_end_state_exists == True:
208
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
209
+ alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
210
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
211
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
212
+
213
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
214
+ pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
215
+
216
+ if alt_end_state_exists == False:
217
+ data_i = [token,
218
+ name,
219
+ ensembl_id,
220
+ mean_goal_end,
221
+ pval_goal_end]
222
+ elif alt_end_state_exists == True:
223
+ data_i = [token,
224
+ name,
225
+ ensembl_id,
226
+ mean_goal_end,
227
+ mean_alt_end,
228
+ pval_goal_end,
229
+ pval_alt_end]
230
+
231
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
232
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
233
+
234
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
235
+ if alt_end_state_exists == True:
236
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
237
+
238
+ # quantify number of detections of each gene
239
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
240
+
241
+ # sort by shift to desired state
242
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
243
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
244
+ "Shift_to_goal_end",
245
+ "Goal_end_FDR"],
246
+ ascending=[False,False,True])
247
+
248
+ return cos_sims_full_df
249
+
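The statistical core of isp_stats_to_goal_state above is a Wilcoxon rank-sum test of each gene's shifts against the random-perturbation background, followed by the Benjamini-Hochberg correction from get_fdr. A standalone sketch on synthetic arrays (gene names and distributions are invented for illustration):

import numpy as np
from scipy.stats import ranksums
import statsmodels.stats.multitest as smt

rng = np.random.default_rng(42)
random_shifts = rng.normal(0.0, 0.05, 10_000)           # background shifts from random perturbations
gene_shifts = {"geneA": rng.normal(0.10, 0.05, 200),    # shifted toward the goal state
               "geneB": rng.normal(0.00, 0.05, 200)}    # indistinguishable from background

pvals = [ranksums(random_shifts, shifts).pvalue for shifts in gene_shifts.values()]
fdrs = smt.multipletests(pvals, alpha=0.05, method="fdr_bh")[1]
print(dict(zip(gene_shifts, fdrs)))                     # geneA should come out far more significant than geneB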
250
+ # stats comparing cos sim shifts of test perturbations vs null distribution
251
+ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
252
+ cos_sims_full_df = cos_sims_df.copy()
253
+
254
+ cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
255
+ cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
256
+ cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
257
+ cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
258
+ cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
259
+ cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
260
+ cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
261
+
262
+ for i in trange(cos_sims_df.shape[0]):
263
+ token = cos_sims_df["Gene"][i]
264
+ test_shifts = []
265
+ null_shifts = []
266
+
267
+ for dict_i in dict_list:
268
+ test_shifts += dict_i.get((token, "cell_emb"),[])
269
+
270
+ for dict_i in null_dict_list:
271
+ null_shifts += dict_i.get((token, "cell_emb"),[])
272
+
273
+ cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
274
+ cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
275
+ cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts)
276
+ cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts,
277
+ null_shifts, nan_policy="omit").pvalue
278
+
279
+ cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
280
+ cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
281
+
282
+ cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
283
+
284
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
285
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
286
+ "Test_vs_null_avg_shift",
287
+ "Test_vs_null_FDR"],
288
+ ascending=[False,False,True])
289
+ return cos_sims_full_df
290
+
291
+ # stats for identifying perturbations with largest effect within a given set of cells
292
+ # fits a mixture model to 2 components (impact vs. non-impact) and
293
+ # reports the most likely component for each test perturbation
294
+ # Note: because this assumes a given perturbation has a consistent effect in the cells tested,
295
+ # we recommend only using the mixture model strategy with uniform cell populations
296
+ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
297
+
298
+ names=["Gene",
299
+ "Gene_name",
300
+ "Ensembl_ID"]
301
+
302
+ if combos == 0:
303
+ names += ["Test_avg_shift"]
304
+ elif combos == 1:
305
+ names += ["Anchor_shift",
306
+ "Test_token_shift",
307
+ "Sum_of_indiv_shifts",
308
+ "Combo_shift",
309
+ "Combo_minus_sum_shift"]
310
+
311
+ names += ["Impact_component",
312
+ "Impact_component_percent"]
313
+
314
+ cos_sims_full_df = pd.DataFrame(columns=names)
315
+ avg_values = []
316
+ gene_names = []
317
+
318
+ for i in trange(cos_sims_df.shape[0]):
319
+ token = cos_sims_df["Gene"][i]
320
+ name = cos_sims_df["Gene_name"][i]
321
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
322
+ cos_shift_data = []
323
+
324
+ for dict_i in dict_list:
325
+ if (combos == 0) and (anchor_token is not None):
326
+ cos_shift_data += dict_i.get((anchor_token, token),[])
327
+ else:
328
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
329
+
330
+ # Extract values for current gene
331
+ if combos == 0:
332
+ test_values = cos_shift_data
333
+ elif combos == 1:
334
+ test_values = []
335
+ for tup in cos_shift_data:
336
+ test_values.append(tup[2])
337
+
338
+ if len(test_values) > 0:
339
+ avg_value = np.mean(test_values)
340
+ avg_values.append(avg_value)
341
+ gene_names.append(name)
342
+
343
+ # fit Gaussian mixture model to the per-gene mean shifts
344
+ avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
345
+ gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
346
+
347
+ for i in trange(cos_sims_df.shape[0]):
348
+ token = cos_sims_df["Gene"][i]
349
+ name = cos_sims_df["Gene_name"][i]
350
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
351
+ cos_shift_data = []
352
+
353
+ for dict_i in dict_list:
354
+ if (combos == 0) and (anchor_token is not None):
355
+ cos_shift_data += dict_i.get((anchor_token, token),[])
356
+ else:
357
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
358
+
359
+ if combos == 0:
360
+ mean_test = np.mean(cos_shift_data)
361
+ impact_components = [get_impact_component(value,gm) for value in cos_shift_data]
362
+ elif combos == 1:
363
+ anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data]
364
+ token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data]
365
+ anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data]
366
+ combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data]
367
+ combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data]
368
+
369
+ mean_anchor = np.mean(anchor_cos_sim_megalist)
370
+ mean_token = np.mean(token_cos_sim_megalist)
371
+ mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
372
+ mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
373
+ mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
374
+
375
+ impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist]
376
+
377
+ impact_component = get_impact_component(mean_test,gm)
378
+ impact_component_percent = np.mean(impact_components)*100
379
+
380
+ data_i = [token,
381
+ name,
382
+ ensembl_id]
383
+ if combos == 0:
384
+ data_i += [mean_test]
385
+ elif combos == 1:
386
+ data_i += [mean_anchor,
387
+ mean_token,
388
+ mean_sum,
389
+ mean_test,
390
+ mean_combo_minus_sum]
391
+ data_i += [impact_component,
392
+ impact_component_percent]
393
+
394
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
395
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
396
+
397
+ # quantify number of detections of each gene
398
+ cos_sims_full_df["N_Detections"] = [n_detections(i,
399
+ dict_list,
400
+ "gene",
401
+ anchor_token) for i in cos_sims_full_df["Gene"]]
402
+
403
+ if combos == 0:
404
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
405
+ "Test_avg_shift"],
406
+ ascending=[False,True])
407
+ elif combos == 1:
408
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
409
+ "Combo_minus_sum_shift"],
410
+ ascending=[False,True])
411
+ return cos_sims_full_df
412
+
413
+ class InSilicoPerturberStats:
414
+ valid_option_dict = {
415
+ "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
416
+ "combos": {0,1},
417
+ "anchor_gene": {None, str},
418
+ "cell_states_to_model": {None, dict},
419
+ }
420
+ def __init__(
421
+ self,
422
+ mode="mixture_model",
423
+ genes_perturbed="all",
424
+ combos=0,
425
+ anchor_gene=None,
426
+ cell_states_to_model=None,
427
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
428
+ gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
429
+ ):
430
+ """
431
+ Initialize in silico perturber stats generator.
432
+
433
+ Parameters
434
+ ----------
435
+ mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
436
+ Type of stats.
437
+ "goal_state_shift": perturbation vs. random for desired cell state shift
438
+ "vs_null": perturbation vs. null from provided null distribution dataset
439
+ "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
440
+ "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
441
+ genes_perturbed : "all", list
442
+ Genes perturbed in isp experiment.
443
+ Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
444
+ Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
445
+ combos : {0,1}
446
+ Whether genes were perturbed individually (0) or in pairs (1).
447
+ anchor_gene : None, str
448
+ ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
449
+ For example, if combos=1 and anchor_gene="ENSG00000136574":
450
+ analyzes data for anchor gene perturbed in combination with each other gene.
451
+ However, if combos=0 and anchor_gene="ENSG00000136574":
452
+ analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
453
+ cell_states_to_model: None, dict
454
+ Cell states to model if testing perturbations that achieve goal state change.
455
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
456
+ state_key: key specifying name of column in .dataset that defines the start/goal states
457
+ start_state: value in the state_key column that specifies the start state
458
+ goal_state: value in the state_key column that specifies the goal end state
459
+ alt_states: list of values in the state_key column that specify the alternate end states
460
+ For example: {"state_key": "disease",
461
+ "start_state": "dcm",
462
+ "goal_state": "nf",
463
+ "alt_states": ["hcm", "other1", "other2"]}
464
+ token_dictionary_file : Path
465
+ Path to pickle file containing token dictionary (Ensembl ID:token).
466
+ gene_name_id_dictionary_file : Path
467
+ Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
468
+ """
469
+
470
+ self.mode = mode
471
+ self.genes_perturbed = genes_perturbed
472
+ self.combos = combos
473
+ self.anchor_gene = anchor_gene
474
+ self.cell_states_to_model = cell_states_to_model
475
+
476
+ self.validate_options()
477
+
478
+ # load token dictionary (Ensembl IDs:token)
479
+ with open(token_dictionary_file, "rb") as f:
480
+ self.gene_token_dict = pickle.load(f)
481
+
482
+ # load gene name dictionary (gene name:Ensembl ID)
483
+ with open(gene_name_id_dictionary_file, "rb") as f:
484
+ self.gene_name_id_dict = pickle.load(f)
485
+
486
+ if anchor_gene is None:
487
+ self.anchor_token = None
488
+ else:
489
+ self.anchor_token = self.gene_token_dict[self.anchor_gene]
490
+
491
+ def validate_options(self):
492
+ for attr_name,valid_options in self.valid_option_dict.items():
493
+ attr_value = self.__dict__[attr_name]
494
+ if type(attr_value) not in {list, dict}:
495
+ if attr_name in {"anchor_gene"}:
496
+ continue
497
+ elif attr_value in valid_options:
498
+ continue
499
+ valid_type = False
500
+ for option in valid_options:
501
+ if (option in [int,list,dict]) and isinstance(attr_value, option):
502
+ valid_type = True
503
+ break
504
+ if valid_type:
505
+ continue
506
+ logger.error(
507
+ f"Invalid option for {attr_name}. " \
508
+ f"Valid options for {attr_name}: {valid_options}"
509
+ )
510
+ raise
511
+
512
+ if self.cell_states_to_model is not None:
513
+ if len(self.cell_states_to_model.items()) == 1:
514
+ logger.warning(
515
+ "The single value dictionary for cell_states_to_model will be " \
516
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
517
+ "Please specify state_key, start_state, goal_state, and alt_states " \
518
+ "in the cell_states_to_model dictionary for future use. " \
519
+ "For example, cell_states_to_model={" \
520
+ "'state_key': 'disease', " \
521
+ "'start_state': 'dcm', " \
522
+ "'goal_state': 'nf', " \
523
+ "'alt_states': ['hcm', 'other1', 'other2']}"
524
+ )
525
+ for key,value in self.cell_states_to_model.items():
526
+ if (len(value) == 3) and isinstance(value, tuple):
527
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
528
+ if len(value[0]) == 1 and len(value[1]) == 1:
529
+ all_values = value[0]+value[1]+value[2]
530
+ if len(all_values) == len(set(all_values)):
531
+ continue
532
+ # reformat to the new named key format
533
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
534
+ self.cell_states_to_model = {
535
+ "state_key": list(self.cell_states_to_model.keys())[0],
536
+ "start_state": state_values[0][0],
537
+ "goal_state": state_values[1][0],
538
+ "alt_states": state_values[2:][0]
539
+ }
540
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
541
+ if (self.cell_states_to_model["state_key"] is None) \
542
+ or (self.cell_states_to_model["start_state"] is None) \
543
+ or (self.cell_states_to_model["goal_state"] is None):
544
+ logger.error(
545
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
546
+ raise
547
+
548
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
549
+ logger.error(
550
+ "All states must be unique.")
551
+ raise
552
+
553
+ if self.cell_states_to_model["alt_states"] is not None:
554
+ if type(self.cell_states_to_model["alt_states"]) is not list:
555
+ logger.error(
556
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
557
+ )
558
+ raise
559
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
560
+ logger.error(
561
+ "All states must be unique.")
562
+ raise
563
+
564
+ else:
565
+ logger.error(
566
+ "cell_states_to_model must only have the following four keys: " \
567
+ "'state_key', 'start_state', 'goal_state', 'alt_states'. " \
568
+ "For example, cell_states_to_model={" \
569
+ "'state_key': 'disease', " \
570
+ "'start_state': 'dcm', " \
571
+ "'goal_state': 'nf', " \
572
+ "'alt_states': ['hcm', 'other1', 'other2']}"
573
+ )
574
+ raise
575
+
576
+ if self.anchor_gene is not None:
577
+ self.anchor_gene = None
578
+ logger.warning(
579
+ "anchor_gene set to None. " \
580
+ "Currently, anchor gene not available " \
581
+ "when modeling multiple cell states.")
582
+
583
+ if self.combos > 0:
584
+ if self.anchor_gene is None:
585
+ logger.error(
586
+ "Currently, stats are only supported for combination " \
587
+ "in silico perturbation run with anchor gene. Please add " \
588
+ "anchor gene when using with combos > 0. ")
589
+ raise
590
+
591
+ if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
592
+ logger.error(
593
+ "Mixture model mode requires multiple gene perturbations to fit model " \
594
+ "so is incompatible with a single grouped perturbation.")
595
+ raise
596
+ if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
597
+ logger.error(
598
+ "Simple data aggregation mode is for single perturbation in multiple cells " \
599
+ "so is incompatible with a genes_perturbed being 'all'.")
600
+ raise
601
+
602
+ def get_stats(self,
603
+ input_data_directory,
604
+ null_dist_data_directory,
605
+ output_directory,
606
+ output_prefix):
607
+ """
608
+ Get stats for in silico perturbation data and save as results in output_directory.
609
+
610
+ Parameters
611
+ ----------
612
+ input_data_directory : Path
613
+ Path to directory containing cos_sim dictionary inputs
614
+ null_dist_data_directory : Path
615
+ Path to directory containing null distribution cos_sim dictionary inputs
616
+ output_directory : Path
617
+ Path to directory where perturbation data will be saved as .csv
618
+ output_prefix : str
619
+ Prefix for output .csv
620
+
621
+ Outputs
622
+ ----------
623
+ Definition of possible columns in .csv output file.
624
+
625
+ Of note, not all columns will be present in all output files.
626
+ Some columns are specific to particular perturbation modes.
627
+
628
+ "Gene": gene token
629
+ "Gene_name": gene name
630
+ "Ensembl_ID": gene Ensembl ID
631
+ "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
632
+ "Sig": 1 if FDR<0.05, otherwise 0
633
+
634
+ "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
635
+ "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
636
+ "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
637
+ pvalue compares the shift caused by perturbing the given gene to the shift caused by random genes
638
+ "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
639
+ pvalue compares the shift caused by perturbing the given gene to the shift caused by random genes
640
+ "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
641
+ "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
642
+
643
+ "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
644
+ "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
645
+ "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
646
+ (i.e. "Test_avg_shift" minus "Null_avg_shift")
647
+ "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
648
+ "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
649
+ "N_Detections_test": "N_Detections" in cells from test distribution
650
+ "N_Detections_null": "N_Detections" in cells from null distribution
651
+
652
+ "Anchor_shift": cosine shift in response to given perturbation of anchor gene
653
+ "Test_token_shift": cosine shift in response to given perturbation of test gene
654
+ "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
655
+ "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
656
+ "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
657
+ (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
658
+ "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
659
+ 1: within impact component; 0: not within impact component
660
+ "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
661
+ """
662
+
663
+ if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
664
+ logger.error(
665
+ "Currently, only modes available are stats for goal_state_shift, " \
666
+ "vs_null (comparing to null distribution), and " \
667
+ "mixture_model (fitting mixture model for perturbations with or without impact.")
668
+ raise
669
+
670
+ self.gene_token_id_dict = invert_dict(self.gene_token_dict)
671
+ self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
672
+
673
+ # obtain total gene list
674
+ if (self.combos == 0) and (self.anchor_token is not None):
675
+ # cos sim data for effect of gene perturbation on the embedding of each other gene
676
+ dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token)
677
+ gene_list = get_gene_list(dict_list, "gene")
678
+ else:
679
+ # cos sim data for effect of gene perturbation on the embedding of each cell
680
+ dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token)
681
+ gene_list = get_gene_list(dict_list, "cell")
682
+
683
+ # initiate results dataframe
684
+ cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
685
+ "Gene_name": [self.token_to_gene_name(item) \
686
+ for item in gene_list], \
687
+ "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
688
+ if self.genes_perturbed != "all" else \
689
+ self.gene_token_id_dict[genes[1]] \
690
+ if isinstance(genes,tuple) else \
691
+ self.gene_token_id_dict[genes] \
692
+ for genes in gene_list]}, \
693
+ index=[i for i in range(len(gene_list))])
694
+
695
+ if self.mode == "goal_state_shift":
696
+ cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
697
+
698
+ elif self.mode == "vs_null":
699
+ null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
700
+ cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
701
+
702
+ elif self.mode == "mixture_model":
703
+ cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
704
+
705
+ elif self.mode == "aggregate_data":
706
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
707
+
708
+ # save perturbation stats to output_path
709
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
710
+ cos_sims_df.to_csv(output_path)
711
+
712
+ def token_to_gene_name(self, item):
713
+ if isinstance(item,int):
714
+ return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
715
+ if isinstance(item,tuple):
716
+ return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])
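For orientation, below is a minimal usage sketch of the stats workflow defined above. The constructor keywords simply mirror the attributes validated in this file (mode, genes_perturbed, combos, anchor_gene, cell_states_to_model); the exact class name and constructor signature are assumed rather than shown in this diff, and all paths are placeholders.

```python
# Hedged usage sketch: assumes the class above is exported as InSilicoPerturberStats and that
# cos_sim dictionaries from a prior in silico perturbation run already exist in the input directory.
from geneformer import InSilicoPerturberStats

ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",      # one of: goal_state_shift, vs_null, mixture_model, aggregate_data
    genes_perturbed="all",
    combos=0,
    anchor_gene=None,
    cell_states_to_model={
        "state_key": "disease",
        "start_state": "dcm",
        "goal_state": "nf",
        "alt_states": ["hcm"],
    },
)

# Writes <output_prefix>.csv with the columns documented in the get_stats docstring.
ispstats.get_stats(
    "path/to/in_silico_perturbation_dicts",   # placeholder input_data_directory
    None,                                      # null_dist_data_directory (unused for goal_state_shift)
    "path/to/stats_output",                    # placeholder output_directory
    "dcm_to_nf_perturbation_stats",            # placeholder output_prefix
)
```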
geneformer/pretrainer.py ADDED
@@ -0,0 +1,822 @@
1
+ """
2
+ Geneformer precollator and pretrainer.
3
+
4
+ Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
5
+ """
6
+ import collections
7
+ import math
8
+ import pickle
9
+ import warnings
10
+ from enum import Enum
11
+ from typing import Dict, Iterator, List, Optional, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ from datasets import Dataset
16
+ from packaging import version
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.utils.data.sampler import RandomSampler
19
+ from transformers import (
20
+ BatchEncoding,
21
+ DataCollatorForLanguageModeling,
22
+ SpecialTokensMixin,
23
+ Trainer,
24
+ )
25
+ from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
26
+ from transformers.trainer_pt_utils import (
27
+ DistributedLengthGroupedSampler,
28
+ DistributedSamplerWithLoop,
29
+ LengthGroupedSampler,
30
+ )
31
+ from transformers.training_args import ParallelMode
32
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
+ from transformers.utils.generic import _is_tensorflow, _is_torch
34
+
35
+ from .tokenizer import TOKEN_DICTIONARY_FILE
36
+
37
+ logger = logging.get_logger(__name__)
38
+ EncodedInput = List[int]
39
+ VERY_LARGE_INTEGER = int(
40
+ 1e30
41
+ ) # This is used to set the max input length for a model with infinite size input
42
+ LARGE_INTEGER = int(
43
+ 1e20
44
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
45
+
46
+ if is_sagemaker_dp_enabled():
47
+ import smdistributed.dataparallel.torch.distributed as dist
48
+ else:
49
+ import torch.distributed as dist
50
+
51
+ _is_torch_generator_available = False
52
+ if version.parse(torch.__version__) >= version.parse("1.6"):
53
+ _is_torch_generator_available = True
54
+
55
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
56
+ token_dictionary = pickle.load(f)
57
+
58
+
59
+ class ExplicitEnum(Enum):
60
+ """
61
+ Enum with more explicit error message for missing values.
62
+ """
63
+
64
+ @classmethod
65
+ def _missing_(cls, value):
66
+ raise ValueError(
67
+ "%r is not a valid %s, please select one of %s"
68
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
69
+ )
70
+
71
+
72
+ class TruncationStrategy(ExplicitEnum):
73
+ """
74
+ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
75
+ tab-completion in an IDE.
76
+ """
77
+
78
+ ONLY_FIRST = "only_first"
79
+ ONLY_SECOND = "only_second"
80
+ LONGEST_FIRST = "longest_first"
81
+ DO_NOT_TRUNCATE = "do_not_truncate"
82
+
83
+
84
+ class PaddingStrategy(ExplicitEnum):
85
+ """
86
+ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
87
+ in an IDE.
88
+ """
89
+
90
+ LONGEST = "longest"
91
+ MAX_LENGTH = "max_length"
92
+ DO_NOT_PAD = "do_not_pad"
93
+
94
+
95
+ class TensorType(ExplicitEnum):
96
+ """
97
+ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
98
+ tab-completion in an IDE.
99
+ """
100
+
101
+ PYTORCH = "pt"
102
+ TENSORFLOW = "tf"
103
+ NUMPY = "np"
104
+ JAX = "jax"
105
+
106
+
107
+ class GeneformerPreCollator(SpecialTokensMixin):
108
+ def __init__(self, *args, **kwargs) -> None:
109
+
110
+ super().__init__(mask_token = "<mask>", pad_token = "<pad>")
111
+
112
+ self.token_dictionary = kwargs.get("token_dictionary")
113
+ # self.mask_token = "<mask>"
114
+ # self.mask_token_id = self.token_dictionary.get("<mask>")
115
+ # self.pad_token = "<pad>"
116
+ # self.pad_token_id = self.token_dictionary.get("<pad>")
117
+ self.padding_side = "right"
118
+ # self.all_special_ids = [
119
+ # self.token_dictionary.get("<mask>"),
120
+ # self.token_dictionary.get("<pad>"),
121
+ # ]
122
+ self.model_input_names = ["input_ids"]
123
+
124
+ def convert_ids_to_tokens(self,value):
125
+ return self.token_dictionary.get(value)
126
+
127
+ def _get_padding_truncation_strategies(
128
+ self,
129
+ padding=False,
130
+ truncation=False,
131
+ max_length=None,
132
+ pad_to_multiple_of=None,
133
+ verbose=True,
134
+ **kwargs,
135
+ ):
136
+ """
137
+ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
138
+ and pad_to_max_length) and behaviors.
139
+ """
140
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
141
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
142
+
143
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
144
+ # If you only set max_length, it activates truncation for max_length
145
+ if max_length is not None and padding is False and truncation is False:
146
+ if verbose:
147
+ if not self.deprecation_warnings.get(
148
+ "Truncation-not-explicitly-activated", False
149
+ ):
150
+ logger.warning(
151
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, "
152
+ "please use `truncation=True` to explicitly truncate examples to max length. "
153
+ "Defaulting to 'longest_first' truncation strategy. "
154
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
155
+ "more precisely by providing a specific strategy to `truncation`."
156
+ )
157
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
158
+ truncation = "longest_first"
159
+
160
+ # Get padding strategy
161
+ if padding is False and old_pad_to_max_length:
162
+ if verbose:
163
+ warnings.warn(
164
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
165
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
166
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
167
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
168
+ "maximal input size of the model (e.g. 512 for Bert).",
169
+ FutureWarning,
170
+ )
171
+ if max_length is None:
172
+ padding_strategy = PaddingStrategy.LONGEST
173
+ else:
174
+ padding_strategy = PaddingStrategy.MAX_LENGTH
175
+ elif padding is not False:
176
+ if padding is True:
177
+ padding_strategy = (
178
+ PaddingStrategy.LONGEST
179
+ ) # Default to pad to the longest sequence in the batch
180
+ elif not isinstance(padding, PaddingStrategy):
181
+ padding_strategy = PaddingStrategy(padding)
182
+ elif isinstance(padding, PaddingStrategy):
183
+ padding_strategy = padding
184
+ else:
185
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
186
+
187
+ # Get truncation strategy
188
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
189
+ if verbose:
190
+ warnings.warn(
191
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
192
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
193
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
194
+ "maximal input size of the model (e.g. 512 for Bert). "
195
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
196
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
197
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
198
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
199
+ FutureWarning,
200
+ )
201
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
202
+ elif truncation is not False:
203
+ if truncation is True:
204
+ truncation_strategy = (
205
+ TruncationStrategy.LONGEST_FIRST
206
+ ) # Default to truncate the longest sequences in pairs of inputs
207
+ elif not isinstance(truncation, TruncationStrategy):
208
+ truncation_strategy = TruncationStrategy(truncation)
209
+ elif isinstance(truncation, TruncationStrategy):
210
+ truncation_strategy = truncation
211
+ else:
212
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
213
+
214
+ # Set max length if needed
215
+ if max_length is None:
216
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
217
+ if self.model_max_length > LARGE_INTEGER:
218
+ if verbose:
219
+ if not self.deprecation_warnings.get(
220
+ "Asking-to-pad-to-max_length", False
221
+ ):
222
+ logger.warning(
223
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
224
+ "Default to no padding."
225
+ )
226
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
227
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
228
+ else:
229
+ max_length = self.model_max_length
230
+
231
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
232
+ if self.model_max_length > LARGE_INTEGER:
233
+ if verbose:
234
+ if not self.deprecation_warnings.get(
235
+ "Asking-to-truncate-to-max_length", False
236
+ ):
237
+ logger.warning(
238
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
239
+ "Default to no truncation."
240
+ )
241
+ self.deprecation_warnings[
242
+ "Asking-to-truncate-to-max_length"
243
+ ] = True
244
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
245
+ else:
246
+ max_length = self.model_max_length
247
+
248
+ # Test if we have a padding token
249
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
250
+ not self.pad_token or self.pad_token_id < 0
251
+ ):
252
+ raise ValueError(
253
+ "Asking to pad but the tokenizer does not have a padding token. "
254
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
255
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
256
+ )
257
+
258
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
259
+ if (
260
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
261
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
262
+ and pad_to_multiple_of is not None
263
+ and max_length is not None
264
+ and (max_length % pad_to_multiple_of != 0)
265
+ ):
266
+ raise ValueError(
267
+ f"Truncation and padding are both activated but "
268
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
269
+ )
270
+
271
+ return padding_strategy, truncation_strategy, max_length, kwargs
272
+
273
+ def pad(
274
+ self,
275
+ encoded_inputs: Union[
276
+ BatchEncoding,
277
+ List[BatchEncoding],
278
+ Dict[str, EncodedInput],
279
+ Dict[str, List[EncodedInput]],
280
+ List[Dict[str, EncodedInput]],
281
+ ],
282
+ padding: Union[bool, str, PaddingStrategy] = True,
283
+ max_length: Optional[int] = None,
284
+ pad_to_multiple_of: Optional[int] = None,
285
+ return_attention_mask: Optional[bool] = True,
286
+ return_tensors: Optional[Union[str, TensorType]] = None,
287
+ verbose: bool = True,
288
+ ) -> BatchEncoding:
289
+ """
290
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
291
+ in the batch.
292
+
293
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
294
+ ``self.pad_token_id`` and ``self.pad_token_type_id``)
295
+
296
+ .. note::
297
+
298
+ If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
299
+ result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
300
+ case of PyTorch tensors, you will lose the specific device of your tensors however.
301
+
302
+ Args:
303
+ encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
304
+ Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
305
+ List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
306
+ List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
307
+ well as in a PyTorch Dataloader collate function.
308
+
309
+ Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
310
+ see the note above for the return type.
311
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
312
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
313
+ index) among:
314
+
315
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
316
+ single sequence if provided).
317
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
318
+ maximum acceptable input length for the model if that argument is not provided.
319
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
320
+ different lengths).
321
+ max_length (:obj:`int`, `optional`):
322
+ Maximum length of the returned list and optionally padding length (see above).
323
+ pad_to_multiple_of (:obj:`int`, `optional`):
324
+ If set will pad the sequence to a multiple of the provided value.
325
+
326
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
327
+ >= 7.5 (Volta).
328
+ return_attention_mask (:obj:`bool`, `optional`):
329
+ Whether to return the attention mask. If left to the default, will return the attention mask according
330
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
331
+
332
+ `What are attention masks? <../glossary.html#attention-mask>`__
333
+ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
334
+ If set, will return tensors instead of list of python integers. Acceptable values are:
335
+
336
+ * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
337
+ * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
338
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
339
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
340
+ Whether or not to print more information and warnings.
341
+ """
342
+ # If we have a list of dicts, let's convert it in a dict of lists
343
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
344
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
345
+ encoded_inputs[0], (dict, BatchEncoding)
346
+ ):
347
+ encoded_inputs = {
348
+ key: [example[key] for example in encoded_inputs]
349
+ for key in encoded_inputs[0].keys()
350
+ }
351
+
352
+ # The model's main input name, usually `input_ids`, has be passed for padding
353
+ if self.model_input_names[0] not in encoded_inputs:
354
+ raise ValueError(
355
+ "You should supply an encoding or a list of encodings to this method"
356
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
357
+ )
358
+
359
+ required_input = encoded_inputs[self.model_input_names[0]]
360
+
361
+ if not required_input:
362
+ if return_attention_mask:
363
+ encoded_inputs["attention_mask"] = []
364
+ return encoded_inputs
365
+
366
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
367
+ # and rebuild them afterwards if no return_tensors is specified
368
+ # Note that we lose the specific device the tensor may be on for PyTorch
369
+
370
+ first_element = required_input[0]
371
+ if isinstance(first_element, (list, tuple)):
372
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
373
+ index = 0
374
+ while len(required_input[index]) == 0:
375
+ index += 1
376
+ if index < len(required_input):
377
+ first_element = required_input[index][0]
378
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
379
+ if not isinstance(first_element, (int, list, tuple)):
380
+ if is_tf_available() and _is_tensorflow(first_element):
381
+ return_tensors = "tf" if return_tensors is None else return_tensors
382
+ elif is_torch_available() and _is_torch(first_element):
383
+ return_tensors = "pt" if return_tensors is None else return_tensors
384
+ if isinstance(first_element, np.ndarray):
385
+ return_tensors = "np" if return_tensors is None else return_tensors
386
+ else:
387
+ raise ValueError(
388
+ f"type of {first_element} unknown: {type(first_element)}. "
389
+ f"Should be one of a python, numpy, pytorch or tensorflow object."
390
+ )
391
+
392
+ for key, value in encoded_inputs.items():
393
+ encoded_inputs[key] = to_py_obj(value)
394
+
395
+
396
+ # Convert padding_strategy in PaddingStrategy
397
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
398
+ padding=padding, max_length=max_length, verbose=verbose
399
+ )
400
+
401
+ required_input = encoded_inputs[self.model_input_names[0]]
402
+ if required_input and not isinstance(required_input[0], (list, tuple)):
403
+ encoded_inputs = self._pad(
404
+ encoded_inputs,
405
+ max_length=max_length,
406
+ padding_strategy=padding_strategy,
407
+ pad_to_multiple_of=pad_to_multiple_of,
408
+ return_attention_mask=return_attention_mask,
409
+ )
410
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
411
+
412
+ batch_size = len(required_input)
413
+ assert all(
414
+ len(v) == batch_size for v in encoded_inputs.values()
415
+ ), "Some items in the output dictionary have a different batch size than others."
416
+
417
+ if padding_strategy == PaddingStrategy.LONGEST:
418
+ max_length = max(len(inputs) for inputs in required_input)
419
+ padding_strategy = PaddingStrategy.MAX_LENGTH
420
+
421
+ batch_outputs = {}
422
+ for i in range(batch_size):
423
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
424
+ outputs = self._pad(
425
+ inputs,
426
+ max_length=max_length,
427
+ padding_strategy=padding_strategy,
428
+ pad_to_multiple_of=pad_to_multiple_of,
429
+ return_attention_mask=return_attention_mask,
430
+ )
431
+
432
+ for key, value in outputs.items():
433
+ if key not in batch_outputs:
434
+ batch_outputs[key] = []
435
+ batch_outputs[key].append(value)
436
+
437
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
438
+
439
+ def _pad(
440
+ self,
441
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
442
+ max_length: Optional[int] = None,
443
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
444
+ pad_to_multiple_of: Optional[int] = None,
445
+ return_attention_mask: Optional[bool] = None,
446
+ ) -> dict:
447
+ """
448
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
449
+
450
+ Args:
451
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
452
+ max_length: maximum length of the returned list and optionally padding length (see below).
453
+ Will truncate by taking into account the special tokens.
454
+ padding_strategy: PaddingStrategy to use for padding.
455
+
456
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
457
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
458
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
459
+ The tokenizer padding sides are defined in self.padding_side:
460
+
461
+ - 'left': pads on the left of the sequences
462
+ - 'right': pads on the right of the sequences
463
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
465
+ >= 7.5 (Volta).
466
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
467
+ """
468
+ # Load from model defaults
469
+ if return_attention_mask is None:
470
+ return_attention_mask = "attention_mask" in self.model_input_names
471
+
472
+ required_input = encoded_inputs[self.model_input_names[0]]
473
+
474
+ if padding_strategy == PaddingStrategy.LONGEST:
475
+ max_length = len(required_input)
476
+
477
+ if (
478
+ max_length is not None
479
+ and pad_to_multiple_of is not None
480
+ and (max_length % pad_to_multiple_of != 0)
481
+ ):
482
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
483
+
484
+ needs_to_be_padded = (
485
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
486
+ and len(required_input) != max_length
487
+ )
488
+
489
+ if needs_to_be_padded:
490
+ difference = max_length - len(required_input)
491
+ if self.padding_side == "right":
492
+ if return_attention_mask:
493
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [
494
+ 0
495
+ ] * difference
496
+ if "token_type_ids" in encoded_inputs:
497
+ encoded_inputs["token_type_ids"] = (
498
+ encoded_inputs["token_type_ids"]
499
+ + [self.pad_token_type_id] * difference
500
+ )
501
+ if "special_tokens_mask" in encoded_inputs:
502
+ encoded_inputs["special_tokens_mask"] = (
503
+ encoded_inputs["special_tokens_mask"] + [1] * difference
504
+ )
505
+ encoded_inputs[self.model_input_names[0]] = (
506
+ required_input + [self.pad_token_id] * difference
507
+ )
508
+ elif self.padding_side == "left":
509
+ if return_attention_mask:
510
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
511
+ required_input
512
+ )
513
+ if "token_type_ids" in encoded_inputs:
514
+ encoded_inputs["token_type_ids"] = [
515
+ self.pad_token_type_id
516
+ ] * difference + encoded_inputs["token_type_ids"]
517
+ if "special_tokens_mask" in encoded_inputs:
518
+ encoded_inputs["special_tokens_mask"] = [
519
+ 1
520
+ ] * difference + encoded_inputs["special_tokens_mask"]
521
+ encoded_inputs[self.model_input_names[0]] = [
522
+ self.pad_token_id
523
+ ] * difference + required_input
524
+ else:
525
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
526
+ elif return_attention_mask and "attention_mask" not in encoded_inputs:
527
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
528
+
529
+ return encoded_inputs
530
+
531
+ def get_special_tokens_mask(
532
+ self,
533
+ token_ids_0: List[int],
534
+ token_ids_1: Optional[List[int]] = None,
535
+ already_has_special_tokens: bool = False,
536
+ ) -> List[int]:
537
+ """
538
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
539
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
540
+ Args:
541
+ token_ids_0 (:obj:`List[int]`):
542
+ List of ids of the first sequence.
543
+ token_ids_1 (:obj:`List[int]`, `optional`):
544
+ List of ids of the second sequence.
545
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
546
+ Whether or not the token list is already formatted with special tokens for the model.
547
+ Returns:
548
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
549
+ """
550
+ assert already_has_special_tokens and token_ids_1 is None, (
551
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
552
+ "Please use a slow (full python) tokenizer to activate this argument."
553
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
554
+ "to get the special tokens mask in any tokenizer. "
555
+ )
556
+
557
+ all_special_ids = self.all_special_ids # cache the property
558
+
559
+ special_tokens_mask = [
560
+ 1 if token in all_special_ids else 0 for token in token_ids_0
561
+ ]
562
+
563
+ return special_tokens_mask
564
+
565
+ def convert_tokens_to_ids(
566
+ self, tokens: Union[str, List[str]]
567
+ ) -> Union[int, List[int]]:
568
+ """
569
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
570
+ vocabulary.
571
+ Args:
572
+ tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
573
+ Returns:
574
+ :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
575
+ """
576
+ if tokens is None:
577
+ return None
578
+
579
+ if isinstance(tokens, str):
580
+ return self._convert_token_to_id_with_added_voc(tokens)
581
+
582
+ ids = []
583
+ for token in tokens:
584
+ ids.append(self._convert_token_to_id_with_added_voc(token))
585
+ return ids
586
+
587
+ def _convert_token_to_id_with_added_voc(self, token):
588
+ if token is None:
589
+ return None
590
+
591
+ return self.token_dictionary.get(token)
592
+
593
+ def __len__(self):
594
+ return len(self.token_dictionary)
595
+
596
+
597
+ class GeneformerPretrainer(Trainer):
598
+ def __init__(self, *args, **kwargs):
599
+ data_collator = kwargs.get("data_collator",None)
600
+ token_dictionary = kwargs.pop("token_dictionary")
601
+
602
+ if data_collator is None:
603
+ precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
604
+
605
+ # # Data Collator Functions
606
+ data_collator = DataCollatorForLanguageModeling(
607
+ tokenizer=precollator, mlm=True, mlm_probability=0.15
608
+ )
609
+ kwargs["data_collator"] = data_collator
610
+
611
+ # load previously saved length vector for dataset to speed up LengthGroupedSampler
612
+ # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
613
+ example_lengths_file = kwargs.pop("example_lengths_file")
614
+ if example_lengths_file:
615
+ with open(example_lengths_file, "rb") as f:
616
+ self.example_lengths = pickle.load(f)
617
+ else:
618
+ raise Exception(
619
+ "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
620
+ )
621
+ super().__init__(*args, **kwargs)
622
+
623
+ # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
624
+ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
625
+ if not isinstance(self.train_dataset, collections.abc.Sized):
626
+ return None
627
+
628
+ generator = None
629
+ if self.args.world_size <= 1 and _is_torch_generator_available:
630
+ generator = torch.Generator()
631
+ generator.manual_seed(
632
+ int(torch.empty((), dtype=torch.int64).random_().item())
633
+ )
634
+
635
+ # Build the sampler.
636
+ if self.args.group_by_length:
637
+ if is_datasets_available() and isinstance(self.train_dataset, Dataset):
638
+ lengths = self.example_lengths
639
+ else:
640
+ lengths = None
641
+ model_input_name = (
642
+ self.tokenizer.model_input_names[0]
643
+ if self.tokenizer is not None
644
+ else None
645
+ )
646
+ if self.args.world_size <= 1:
647
+ return LengthGroupedSampler(
648
+ dataset=self.train_dataset,
649
+ batch_size=self.args.train_batch_size,
650
+ lengths=lengths,
651
+ model_input_name=model_input_name,
652
+ generator=generator,
653
+ )
654
+ else:
655
+ return CustomDistributedLengthGroupedSampler(
656
+ dataset=self.train_dataset,
657
+ batch_size=self.args.train_batch_size,
658
+ num_replicas=self.args.world_size,
659
+ rank=self.args.process_index,
660
+ lengths=lengths,
661
+ model_input_name=model_input_name,
662
+ seed=self.args.seed,
663
+ )
664
+
665
+ else:
666
+ if self.args.world_size <= 1:
667
+ if _is_torch_generator_available:
668
+ return RandomSampler(self.train_dataset, generator=generator)
669
+ return RandomSampler(self.train_dataset)
670
+ elif (
671
+ self.args.parallel_mode
672
+ in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
673
+ and not self.args.dataloader_drop_last
674
+ ):
675
+ # Use a loop for TPUs when drop_last is False to have all batches have the same size.
676
+ return DistributedSamplerWithLoop(
677
+ self.train_dataset,
678
+ batch_size=self.args.per_device_train_batch_size,
679
+ num_replicas=self.args.world_size,
680
+ rank=self.args.process_index,
681
+ seed=self.args.seed,
682
+ )
683
+ else:
684
+ return DistributedSampler(
685
+ self.train_dataset,
686
+ num_replicas=self.args.world_size,
687
+ rank=self.args.process_index,
688
+ seed=self.args.seed,
689
+ )
690
+
691
+
692
+ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
693
+ r"""
694
+ Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
695
+ length while keeping a bit of randomness.
696
+ """
697
+ # Copied and adapted from PyTorch DistributedSampler.
698
+ def __init__(
699
+ self,
700
+ dataset: Dataset,
701
+ batch_size: int,
702
+ num_replicas: Optional[int] = None,
703
+ rank: Optional[int] = None,
704
+ seed: int = 0,
705
+ drop_last: bool = False,
706
+ lengths: Optional[List[int]] = None,
707
+ model_input_name: Optional[str] = None,
708
+ ):
709
+ if num_replicas is None:
710
+ if not dist.is_available():
711
+ raise RuntimeError("Requires distributed package to be available")
712
+ num_replicas = dist.get_world_size()
713
+ if rank is None:
714
+ if not dist.is_available():
715
+ raise RuntimeError("Requires distributed package to be available")
716
+ rank = dist.get_rank()
717
+ self.dataset = dataset
718
+ self.batch_size = batch_size
719
+ self.num_replicas = num_replicas
720
+ self.rank = rank
721
+ self.epoch = 0
722
+ self.drop_last = drop_last
723
+ # If the dataset length is evenly divisible by # of replicas, then there
724
+ # is no need to drop any data, since the dataset will be split equally.
725
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0:
726
+ # Split to nearest available length that is evenly divisible.
727
+ # This is to ensure each rank receives the same amount of data when
728
+ # using this Sampler.
729
+ self.num_samples = math.ceil(
730
+ (len(self.dataset) - self.num_replicas) / self.num_replicas
731
+ )
732
+ else:
733
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
734
+ self.total_size = self.num_samples * self.num_replicas
735
+ self.seed = seed
736
+ self.model_input_name = (
737
+ model_input_name if model_input_name is not None else "input_ids"
738
+ )
739
+
740
+ if lengths is None:
741
+ print("Lengths is none - calculating lengths.")
742
+ if (
743
+ not (
744
+ isinstance(dataset[0], dict)
745
+ or isinstance(dataset[0], BatchEncoding)
746
+ )
747
+ or self.model_input_name not in dataset[0]
748
+ ):
749
+ raise ValueError(
750
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
751
+ f"'{self.model_input_name}' key."
752
+ )
753
+ lengths = [len(feature[self.model_input_name]) for feature in dataset]
754
+ self.lengths = lengths
755
+
756
+ def __iter__(self) -> Iterator:
757
+ # Deterministically shuffle based on epoch and seed
758
+ g = torch.Generator()
759
+ g.manual_seed(self.seed + self.epoch)
760
+
761
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
762
+
763
+ if not self.drop_last:
764
+ # add extra samples to make it evenly divisible
765
+ indices += indices[: (self.total_size - len(indices))]
766
+ else:
767
+ # remove tail of data to make it evenly divisible.
768
+ indices = indices[: self.total_size]
769
+ assert len(indices) == self.total_size
770
+
771
+ # subsample
772
+ indices = indices[self.rank : self.total_size : self.num_replicas]
773
+ assert len(indices) == self.num_samples
774
+
775
+ return iter(indices)
776
+
777
+
778
+ def get_length_grouped_indices(
779
+ lengths, batch_size, mega_batch_mult=None, generator=None
780
+ ):
781
+ """
782
+ Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
783
+ similar lengths. To do this, the indices are:
784
+
785
+ - randomly permuted
786
+ - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
787
+ - sorted by length in each mega-batch
788
+
789
+ The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
790
+ maximum length placed first, so that an OOM happens sooner rather than later.
791
+ """
792
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
793
+ if mega_batch_mult is None:
794
+ # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
795
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
796
+ # Just in case, for tiny datasets
797
+ if mega_batch_mult == 0:
798
+ mega_batch_mult = 1
799
+
800
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
801
+ indices = torch.randperm(len(lengths), generator=generator)
802
+ megabatch_size = mega_batch_mult * batch_size
803
+ megabatches = [
804
+ indices[i : i + megabatch_size].tolist()
805
+ for i in range(0, len(lengths), megabatch_size)
806
+ ]
807
+ megabatches = [
808
+ list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
809
+ for megabatch in megabatches
810
+ ]
811
+
812
+ # The rest is to get the biggest batch first.
813
+ # Since each megabatch is sorted by descending length, the longest element is the first
814
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
815
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
816
+ # Switch to put the longest element in first position
817
+ megabatches[0][0], megabatches[max_idx][0] = (
818
+ megabatches[max_idx][0],
819
+ megabatches[0][0],
820
+ )
821
+
822
+ return [item for sublist in megabatches for item in sublist]
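A brief sketch of how the pretrainer above might be wired together is given below. A BERT-style masked-LM is assumed, consistent with the masked-language-modeling collator defined in this file; the configuration values follow the model description (six encoder units) and the tokenizer's 2048-token truncation, while the dataset path, hyperparameters, and the assumption that GeneformerPretrainer is exported from the package are placeholders rather than fixed choices.

```python
# Hedged sketch of assembling GeneformerPretrainer; all paths and hyperparameters are placeholders.
import pickle

from datasets import load_from_disk
from transformers import BertConfig, BertForMaskedLM, TrainingArguments

from geneformer import GeneformerPretrainer  # assumes the package exports the class above

with open("token_dictionary.pkl", "rb") as f:
    token_dictionary = pickle.load(f)

train_dataset = load_from_disk("genecorpus_30M_2048.dataset")  # tokenized with TranscriptomeTokenizer

config = BertConfig(
    vocab_size=len(token_dictionary),
    num_hidden_layers=6,            # six transformer encoder units
    max_position_embeddings=2048,   # matches the tokenizer's truncation length
)
model = BertForMaskedLM(config)

training_args = TrainingArguments(
    output_dir="pretraining_output",
    group_by_length=True,           # activates the length-grouped samplers defined above
    per_device_train_batch_size=12,
    num_train_epochs=3,
)

trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    token_dictionary=token_dictionary,
    example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",  # precomputed lengths pickle
)
trainer.train()
```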
geneformer/token_dictionary.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcf53d1c87c08786f73aaf7c09da9778bfb8299e86b03411daa4143ac64ac0a7
3
+ size 270111
geneformer/tokenizer.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Geneformer tokenizer.
3
+
4
+ Input data:
5
+ Required format: raw counts scRNAseq data without feature selection as .loom file
6
+ Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene
7
+ Required col (cell) attribute: "n_counts"; total read counts in that cell
8
+ Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria
9
+ Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below
10
+
11
+ Usage:
12
+ from geneformer import TranscriptomeTokenizer
13
+ tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
14
+ tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
15
+ """
16
+
17
+ import pickle
18
+ from pathlib import Path
19
+
20
+ import logging
21
+
22
+ import warnings
23
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
24
+
25
+ import loompy as lp
26
+ import numpy as np
27
+ from datasets import Dataset
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
32
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
33
+
34
+
35
+ def tokenize_cell(gene_vector, gene_tokens):
36
+ """
37
+ Convert normalized gene expression vector to tokenized rank value encoding.
38
+ """
39
+ # create array of gene vector with token indices
40
+ # mask undetected genes
41
+ nonzero_mask = np.nonzero(gene_vector)[0]
42
+ # sort by median-scaled gene values
43
+ sorted_indices = np.argsort(-gene_vector[nonzero_mask])
44
+ # tokenize
45
+ sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
46
+ return sentence_tokens
47
+
48
+
49
+ class TranscriptomeTokenizer:
50
+ def __init__(
51
+ self,
52
+ custom_attr_name_dict=None,
53
+ nproc=1,
54
+ gene_median_file=GENE_MEDIAN_FILE,
55
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
56
+ ):
57
+ """
58
+ Initialize tokenizer.
59
+
60
+ Parameters
61
+ ----------
62
+ custom_attr_name_dict : None, dict
63
+ Dictionary of custom attributes to be added to the dataset.
64
+ Keys are the names of the attributes in the loom file.
65
+ Values are the names of the attributes in the dataset.
66
+ nproc : int
67
+ Number of processes to use for dataset mapping.
68
+ gene_median_file : Path
69
+ Path to pickle file containing dictionary of non-zero median
70
+ gene expression values across Genecorpus-30M.
71
+ token_dictionary_file : Path
72
+ Path to pickle file containing token dictionary (Ensembl IDs:token).
73
+ """
74
+ # dictionary of custom attributes {output dataset column name: input .loom column name}
75
+ self.custom_attr_name_dict = custom_attr_name_dict
76
+
77
+ # number of processes for dataset mapping
78
+ self.nproc = nproc
79
+
80
+ # load dictionary of gene normalization factors
81
+ # (non-zero median value of expression across Genecorpus-30M)
82
+ with open(gene_median_file, "rb") as f:
83
+ self.gene_median_dict = pickle.load(f)
84
+
85
+ # load token dictionary (Ensembl IDs:token)
86
+ with open(token_dictionary_file, "rb") as f:
87
+ self.gene_token_dict = pickle.load(f)
88
+
89
+ # gene keys for full vocabulary
90
+ self.gene_keys = list(self.gene_median_dict.keys())
91
+
92
+ # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
93
+ self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
94
+
95
+ def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
96
+ """
97
+ Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
98
+
99
+ Parameters
100
+ ----------
101
+ loom_data_directory : Path
102
+ Path to directory containing loom files
103
+ output_directory : Path
104
+ Path to directory where tokenized data will be saved as .dataset
105
+ output_prefix : str
106
+ Prefix for output .dataset
107
+ """
108
+ tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
109
+ tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
110
+
111
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
112
+ tokenized_dataset.save_to_disk(output_path)
113
+
114
+ def tokenize_files(self, loom_data_directory):
115
+ tokenized_cells = []
116
+ if self.custom_attr_name_dict is not None:
117
+ loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
118
+ cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
119
+
120
+ # loops through directories to tokenize .loom files
121
+ file_found = 0
122
+ for loom_file_path in loom_data_directory.glob("*.loom"):
123
+ file_found = 1
124
+ print(f"Tokenizing {loom_file_path}")
125
+ file_tokenized_cells, file_cell_metadata = self.tokenize_file(
126
+ loom_file_path
127
+ )
128
+ tokenized_cells += file_tokenized_cells
129
+ if self.custom_attr_name_dict is not None:
130
+ for k in loom_cell_attr:
131
+ cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
132
+ else:
133
+ cell_metadata = None
134
+
135
+ if file_found == 0:
136
+ logger.error(
137
+ f"No .loom files found in directory {loom_data_directory}.")
138
+ raise
139
+ return tokenized_cells, cell_metadata
140
+
141
+ def tokenize_file(self, loom_file_path):
142
+ if self.custom_attr_name_dict is not None:
143
+ file_cell_metadata = {
144
+ attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
145
+ }
146
+
147
+ with lp.connect(str(loom_file_path)) as data:
148
+ # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
149
+ coding_miRNA_loc = np.where(
150
+ [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
151
+ )[0]
152
+ norm_factor_vector = np.array(
153
+ [
154
+ self.gene_median_dict[i]
155
+ for i in data.ra["ensembl_id"][coding_miRNA_loc]
156
+ ]
157
+ )
158
+ coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
159
+ coding_miRNA_tokens = np.array(
160
+ [self.gene_token_dict[i] for i in coding_miRNA_ids]
161
+ )
162
+
163
+ # define coordinates of cells passing filters for inclusion (e.g. QC)
164
+ try:
165
+ data.ca["filter_pass"]
166
+ except AttributeError:
167
+ var_exists = False
168
+ else:
169
+ var_exists = True
170
+
171
+ if var_exists is True:
172
+ filter_pass_loc = np.where(
173
+ [True if i == 1 else False for i in data.ca["filter_pass"]]
174
+ )[0]
175
+ elif var_exists is False:
176
+ print(
177
+ f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
178
+ )
179
+ filter_pass_loc = np.array([i for i in range(data.shape[1])])
180
+
181
+ # scan through .loom files and tokenize cells
182
+ tokenized_cells = []
183
+ for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
184
+ # select subview with protein-coding and miRNA genes
185
+ subview = view.view[coding_miRNA_loc, :]
186
+
187
+ # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
188
+ # and normalize by gene normalization factors
189
+ subview_norm_array = (
190
+ subview[:, :]
191
+ / subview.ca.n_counts
192
+ * 10_000
193
+ / norm_factor_vector[:, None]
194
+ )
195
+ # tokenize subview gene vectors
196
+ tokenized_cells += [
197
+ tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
198
+ for i in range(subview_norm_array.shape[1])
199
+ ]
200
+
201
+ # add custom attributes for subview to dict
202
+ if self.custom_attr_name_dict is not None:
203
+ for k in file_cell_metadata.keys():
204
+ file_cell_metadata[k] += subview.ca[k].tolist()
205
+ else:
206
+ file_cell_metadata = None
207
+
208
+ return tokenized_cells, file_cell_metadata
209
+
210
+ def create_dataset(self, tokenized_cells, cell_metadata):
211
+ # create dict for dataset creation
212
+ dataset_dict = {"input_ids": tokenized_cells}
213
+ if self.custom_attr_name_dict is not None:
214
+ dataset_dict.update(cell_metadata)
215
+
216
+ # create dataset
217
+ output_dataset = Dataset.from_dict(dataset_dict)
218
+
219
+ # truncate dataset
220
+ def truncate(example):
221
+ example["input_ids"] = example["input_ids"][0:2048]
222
+ return example
223
+
224
+ output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
225
+
226
+ # measure lengths of dataset
227
+ def measure_length(example):
228
+ example["length"] = len(example["input_ids"])
229
+ return example
230
+
231
+ output_dataset_truncated_w_length = output_dataset_truncated.map(
232
+ measure_length, num_proc=self.nproc
233
+ )
234
+
235
+ return output_dataset_truncated_w_length
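To make the rank value encoding above concrete, here is a toy walk-through of the steps inside tokenize_cell with made-up numbers (real inputs are read-depth- and median-normalized counts, as computed in tokenize_file):

```python
# Toy illustration of tokenize_cell: undetected (zero) genes are dropped and the remaining genes
# are ranked by normalized expression, highest first. Values and token ids below are made up.
import numpy as np

gene_vector = np.array([0.0, 2.5, 0.1, 7.3])  # normalized expression for four genes in one cell
gene_tokens = np.array([10, 11, 12, 13])      # token ids for those four genes

nonzero_mask = np.nonzero(gene_vector)[0]                # keep detected genes -> indices [1, 2, 3]
sorted_indices = np.argsort(-gene_vector[nonzero_mask])  # rank by descending expression
print(gene_tokens[nonzero_mask][sorted_indices])         # [13 11 12]
```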
generation_config.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.32.0"
5
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:199d33652d295dfe6ef97b3d3dccdc2f528931ffbe683243ec5a70842637e329
3
+ size 31494773
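The weights above should be loadable through the standard transformers API once the repository files are in place; the snippet below is a hedged sketch, and the model identifier is assumed rather than taken from this commit.

```python
# Hedged sketch: load the checkpoint in this repository with the generic masked-LM auto class.
# Replace the identifier below with the actual model repository ID or a local path to these files.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("path/or/repo-id/for/Geneformer")
print(model.config.num_hidden_layers)  # expected to report six encoder units, per the model description
```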
setup.py ADDED
@@ -0,0 +1,21 @@
1
+ from setuptools import setup
2
+
3
+ setup(
4
+ name="geneformer",
5
+ version="0.0.1",
6
+ author="Christina Theodoris",
7
+ author_email="christina.theodoris@gladstone.ucsf.edu",
8
+ description="Geneformer is a transformer model pretrained \
9
+ on a large-scale corpus of ~30 million single \
10
+ cell transcriptomes to enable context-aware \
11
+ predictions in settings with limited data in \
12
+ network biology.",
13
+ packages=["geneformer"],
14
+ include_package_data=True,
15
+ install_requires=[
16
+ "datasets",
17
+ "loompy",
18
+ "numpy",
19
+ "transformers",
20
+ ],
21
+ )
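Since setup.py sets include_package_data=True, the pickled gene dictionaries are installed alongside the code. A quick, hedged sanity check after `pip install .` could reuse the package-relative path constants defined in tokenizer.py:

```python
# Hedged sanity check that the packaged dictionaries resolve after installation.
import pickle

from geneformer.tokenizer import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE

with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    token_dictionary = pickle.load(f)
with open(GENE_MEDIAN_FILE, "rb") as f:
    gene_median_dict = pickle.load(f)

print(len(token_dictionary), len(gene_median_dict))  # vocabulary and gene-median dictionary sizes
```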
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6aa4702ebe332247df4beb04ad957645d492c05ebca9f9b600770a7c658e7800
3
+ size 4219