Update geneformer/perturber_utils.py
#353
by
hchen725
- opened
- geneformer/perturber_utils.py +18 -63
geneformer/perturber_utils.py
CHANGED
@@ -4,8 +4,6 @@ import pickle
|
|
4 |
import re
|
5 |
from collections import defaultdict
|
6 |
from typing import List
|
7 |
-
from pathlib import Path
|
8 |
-
|
9 |
|
10 |
import numpy as np
|
11 |
import pandas as pd
|
@@ -18,11 +16,6 @@ from transformers import (
|
|
18 |
BertForTokenClassification,
|
19 |
)
|
20 |
|
21 |
-
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
22 |
-
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
23 |
-
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
24 |
-
|
25 |
-
|
26 |
sns.set()
|
27 |
|
28 |
logger = logging.getLogger(__name__)
|
@@ -218,26 +211,35 @@ def delete_indices(example):
|
|
218 |
|
219 |
|
220 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
221 |
-
def overexpress_indices(example):
|
222 |
indices = example["perturb_index"]
|
223 |
if any(isinstance(el, list) for el in indices):
|
224 |
indices = flatten_list(indices)
|
225 |
for index in sorted(indices, reverse=True):
|
226 |
-
|
|
|
|
|
|
|
227 |
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
231 |
|
232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
233 |
-
def overexpress_tokens(example, max_len):
|
234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
235 |
if example["perturb_index"] != [-100]:
|
236 |
example = delete_indices(example)
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
# truncate to max input size, must also truncate original emb to be comparable
|
243 |
if len(example["input_ids"]) > max_len:
|
@@ -588,11 +590,9 @@ def quant_cos_sims(
|
|
588 |
elif emb_mode == "cell":
|
589 |
cos = torch.nn.CosineSimilarity(dim=1)
|
590 |
|
591 |
-
|
592 |
-
# against original cell anyways
|
593 |
-
if cell_states_to_model is None or emb_mode == "gene":
|
594 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
595 |
-
|
596 |
possible_states = get_possible_states(cell_states_to_model)
|
597 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
598 |
for state in possible_states:
|
@@ -714,48 +714,3 @@ def validate_cell_states_to_model(cell_states_to_model):
|
|
714 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
715 |
)
|
716 |
raise
|
717 |
-
|
718 |
-
class GeneIdHandler:
|
719 |
-
def __init__(self, raise_errors=False):
|
720 |
-
def invert_dict(dict_obj):
|
721 |
-
return {v:k for k,v in dict_obj.items()}
|
722 |
-
|
723 |
-
self.raise_errors = raise_errors
|
724 |
-
|
725 |
-
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
726 |
-
self.gene_token_dict = pickle.load(f)
|
727 |
-
self.token_gene_dict = invert_dict(self.gene_token_dict)
|
728 |
-
|
729 |
-
with open(ENSEMBL_DICTIONARY_FILE, 'rb') as f:
|
730 |
-
self.id_gene_dict = pickle.load(f)
|
731 |
-
self.gene_id_dict = invert_dict(self.id_gene_dict)
|
732 |
-
|
733 |
-
def ens_to_token(self, ens_id):
|
734 |
-
if not self.raise_errors:
|
735 |
-
return self.gene_token_dict.get(ens_id, ens_id)
|
736 |
-
else:
|
737 |
-
return self.gene_token_dict[ens_id]
|
738 |
-
|
739 |
-
def token_to_ens(self, token):
|
740 |
-
if not self.raise_errors:
|
741 |
-
return self.token_gene_dict.get(token, token)
|
742 |
-
else:
|
743 |
-
return self.token_gene_dict[token]
|
744 |
-
|
745 |
-
def ens_to_symbol(self, ens_id):
|
746 |
-
if not self.raise_errors:
|
747 |
-
return self.gene_id_dict.get(ens_id, ens_id)
|
748 |
-
else:
|
749 |
-
return self.gene_id_dict[ens_id]
|
750 |
-
|
751 |
-
def symbol_to_ens(self, symbol):
|
752 |
-
if not self.raise_errors:
|
753 |
-
return self.id_gene_dict.get(symbol, symbol)
|
754 |
-
else:
|
755 |
-
return self.id_gene_dict[symbol]
|
756 |
-
|
757 |
-
def token_to_symbol(self, token):
|
758 |
-
return self.ens_to_symbol(self.token_to_ens(token))
|
759 |
-
|
760 |
-
def symbol_to_token(self, symbol):
|
761 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
4 |
import re
|
5 |
from collections import defaultdict
|
6 |
from typing import List
|
|
|
|
|
7 |
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
|
|
16 |
BertForTokenClassification,
|
17 |
)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
sns.set()
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
|
|
211 |
|
212 |
|
213 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
214 |
+
def overexpress_indices(example, special_token):
|
215 |
indices = example["perturb_index"]
|
216 |
if any(isinstance(el, list) for el in indices):
|
217 |
indices = flatten_list(indices)
|
218 |
for index in sorted(indices, reverse=True):
|
219 |
+
if special_token:
|
220 |
+
example["input_ids"].insert(1, example["input_ids"].pop(index))
|
221 |
+
else:
|
222 |
+
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
223 |
|
224 |
example["length"] = len(example["input_ids"])
|
225 |
return example
|
226 |
|
227 |
|
228 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
229 |
+
def overexpress_tokens(example, max_len, special_token):
|
230 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
231 |
if example["perturb_index"] != [-100]:
|
232 |
example = delete_indices(example)
|
233 |
+
if special_token:
|
234 |
+
[
|
235 |
+
example["input_ids"].insert(1, token)
|
236 |
+
for token in example["tokens_to_perturb"][::-1]
|
237 |
+
]
|
238 |
+
else:
|
239 |
+
[
|
240 |
+
example["input_ids"].insert(0, token)
|
241 |
+
for token in example["tokens_to_perturb"][::-1]
|
242 |
+
]
|
243 |
|
244 |
# truncate to max input size, must also truncate original emb to be comparable
|
245 |
if len(example["input_ids"]) > max_len:
|
|
|
590 |
elif emb_mode == "cell":
|
591 |
cos = torch.nn.CosineSimilarity(dim=1)
|
592 |
|
593 |
+
if cell_states_to_model is None:
|
|
|
|
|
594 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
595 |
+
else:
|
596 |
possible_states = get_possible_states(cell_states_to_model)
|
597 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
598 |
for state in possible_states:
|
|
|
714 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
715 |
)
|
716 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|