Update geneformer/perturber_utils.py

#353
by hchen725 - opened
Files changed (1) hide show
  1. geneformer/perturber_utils.py +18 -63
geneformer/perturber_utils.py CHANGED
@@ -4,8 +4,6 @@ import pickle
4
  import re
5
  from collections import defaultdict
6
  from typing import List
7
- from pathlib import Path
8
-
9
 
10
  import numpy as np
11
  import pandas as pd
@@ -18,11 +16,6 @@ from transformers import (
18
  BertForTokenClassification,
19
  )
20
 
21
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
22
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
23
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
24
-
25
-
26
  sns.set()
27
 
28
  logger = logging.getLogger(__name__)
@@ -218,26 +211,35 @@ def delete_indices(example):
218
 
219
 
220
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
221
- def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
  for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
 
 
 
227
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
@@ -588,11 +590,9 @@ def quant_cos_sims(
588
  elif emb_mode == "cell":
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
- # if emb_mode == "gene", can only calculate gene cos sims
592
- # against original cell anyways
593
- if cell_states_to_model is None or emb_mode == "gene":
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
595
- elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
598
  for state in possible_states:
@@ -714,48 +714,3 @@ def validate_cell_states_to_model(cell_states_to_model):
714
  "'alt_states': ['hcm', 'other1', 'other2']}"
715
  )
716
  raise
717
-
718
- class GeneIdHandler:
719
- def __init__(self, raise_errors=False):
720
- def invert_dict(dict_obj):
721
- return {v:k for k,v in dict_obj.items()}
722
-
723
- self.raise_errors = raise_errors
724
-
725
- with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
726
- self.gene_token_dict = pickle.load(f)
727
- self.token_gene_dict = invert_dict(self.gene_token_dict)
728
-
729
- with open(ENSEMBL_DICTIONARY_FILE, 'rb') as f:
730
- self.id_gene_dict = pickle.load(f)
731
- self.gene_id_dict = invert_dict(self.id_gene_dict)
732
-
733
- def ens_to_token(self, ens_id):
734
- if not self.raise_errors:
735
- return self.gene_token_dict.get(ens_id, ens_id)
736
- else:
737
- return self.gene_token_dict[ens_id]
738
-
739
- def token_to_ens(self, token):
740
- if not self.raise_errors:
741
- return self.token_gene_dict.get(token, token)
742
- else:
743
- return self.token_gene_dict[token]
744
-
745
- def ens_to_symbol(self, ens_id):
746
- if not self.raise_errors:
747
- return self.gene_id_dict.get(ens_id, ens_id)
748
- else:
749
- return self.gene_id_dict[ens_id]
750
-
751
- def symbol_to_ens(self, symbol):
752
- if not self.raise_errors:
753
- return self.id_gene_dict.get(symbol, symbol)
754
- else:
755
- return self.id_gene_dict[symbol]
756
-
757
- def token_to_symbol(self, token):
758
- return self.ens_to_symbol(self.token_to_ens(token))
759
-
760
- def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
4
  import re
5
  from collections import defaultdict
6
  from typing import List
 
 
7
 
8
  import numpy as np
9
  import pandas as pd
 
16
  BertForTokenClassification,
17
  )
18
 
 
 
 
 
 
19
  sns.set()
20
 
21
  logger = logging.getLogger(__name__)
 
211
 
212
 
213
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
214
+ def overexpress_indices(example, special_token):
215
  indices = example["perturb_index"]
216
  if any(isinstance(el, list) for el in indices):
217
  indices = flatten_list(indices)
218
  for index in sorted(indices, reverse=True):
219
+ if special_token:
220
+ example["input_ids"].insert(1, example["input_ids"].pop(index))
221
+ else:
222
+ example["input_ids"].insert(0, example["input_ids"].pop(index))
223
 
224
  example["length"] = len(example["input_ids"])
225
  return example
226
 
227
 
228
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
229
+ def overexpress_tokens(example, max_len, special_token):
230
  # -100 indicates tokens to overexpress are not present in rank value encoding
231
  if example["perturb_index"] != [-100]:
232
  example = delete_indices(example)
233
+ if special_token:
234
+ [
235
+ example["input_ids"].insert(1, token)
236
+ for token in example["tokens_to_perturb"][::-1]
237
+ ]
238
+ else:
239
+ [
240
+ example["input_ids"].insert(0, token)
241
+ for token in example["tokens_to_perturb"][::-1]
242
+ ]
243
 
244
  # truncate to max input size, must also truncate original emb to be comparable
245
  if len(example["input_ids"]) > max_len:
 
590
  elif emb_mode == "cell":
591
  cos = torch.nn.CosineSimilarity(dim=1)
592
 
593
+ if cell_states_to_model is None:
 
 
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
595
+ else:
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
598
  for state in possible_states:
 
714
  "'alt_states': ['hcm', 'other1', 'other2']}"
715
  )
716
  raise