ctheodoris
committed on
Commit
•
3e24216
1
Parent(s):
d6e949b
Update geneformer/in_silico_perturber.py
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -99,6 +99,7 @@ class InSilicoPerturber:
|
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
token_dictionary_file=None,
|
|
|
102 |
):
|
103 |
"""
|
104 |
Initialize in silico perturber.
|
@@ -187,6 +188,8 @@ class InSilicoPerturber:
|
|
187 |
| Number of CPU processes to use.
|
188 |
token_dictionary_file : Path
|
189 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
|
|
|
|
190 |
"""
|
191 |
try:
|
192 |
set_start_method("spawn")
|
@@ -224,6 +227,7 @@ class InSilicoPerturber:
|
|
224 |
self.forward_batch_size = forward_batch_size
|
225 |
self.nproc = nproc
|
226 |
self.token_dictionary_file = token_dictionary_file
|
|
|
227 |
|
228 |
self.validate_options()
|
229 |
|
@@ -235,17 +239,16 @@ class InSilicoPerturber:
|
|
235 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
236 |
|
237 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
|
|
|
|
238 |
|
239 |
|
240 |
# Identify if special token is present in the token dictionary
|
241 |
-
|
242 |
-
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
243 |
-
eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
|
244 |
-
if cls_present or eos_present:
|
245 |
self.special_token = True
|
246 |
else:
|
247 |
if "cls" in self.emb_mode:
|
248 |
-
logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
|
249 |
raise
|
250 |
self.special_token = False
|
251 |
|
@@ -454,12 +457,17 @@ class InSilicoPerturber:
|
|
454 |
|
455 |
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
if self.special_token:
|
457 |
-
|
458 |
-
if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
|
459 |
logger.error(
|
460 |
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
461 |
)
|
462 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
|
464 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
465 |
|
@@ -554,6 +562,7 @@ class InSilicoPerturber:
|
|
554 |
perturbed_data = filtered_input_data.map(
|
555 |
make_group_perturbation_batch, num_proc=self.nproc
|
556 |
)
|
|
|
557 |
if self.perturb_type == "overexpress":
|
558 |
filtered_input_data = filtered_input_data.add_column(
|
559 |
"n_overflow", perturbed_data["n_overflow"]
|
@@ -572,7 +581,7 @@ class InSilicoPerturber:
|
|
572 |
pu.truncate_by_n_overflow, num_proc=self.nproc
|
573 |
)
|
574 |
|
575 |
-
if self.emb_mode == "cell_and_gene":
|
576 |
stored_gene_embs_dict = defaultdict(list)
|
577 |
|
578 |
# iterate through batches
|
@@ -618,20 +627,24 @@ class InSilicoPerturber:
|
|
618 |
|
619 |
if "cls" not in self.emb_mode:
|
620 |
start = 0
|
|
|
|
|
621 |
else:
|
622 |
start = 1
|
|
|
|
|
623 |
|
624 |
-
# remove overexpressed genes and cls
|
625 |
original_emb = original_emb[
|
626 |
-
:, start
|
627 |
]
|
628 |
if self.perturb_type == "overexpress":
|
629 |
perturbation_emb = full_perturbation_emb[
|
630 |
-
:, start+len(self.tokens_to_perturb)
|
631 |
]
|
632 |
elif self.perturb_type == "delete":
|
633 |
perturbation_emb = full_perturbation_emb[
|
634 |
-
:, start : max(perturbation_batch["length"]), :
|
635 |
]
|
636 |
|
637 |
n_perturbation_genes = perturbation_emb.size()[1]
|
@@ -640,6 +653,7 @@ class InSilicoPerturber:
|
|
640 |
if (
|
641 |
self.cell_states_to_model is None
|
642 |
or self.emb_mode == "cell_and_gene"
|
|
|
643 |
):
|
644 |
gene_cos_sims = pu.quant_cos_sims(
|
645 |
perturbation_emb,
|
@@ -677,18 +691,23 @@ class InSilicoPerturber:
|
|
677 |
|
678 |
# get cosine similarities in gene embeddings
|
679 |
# if getting gene embeddings, need gene names
|
680 |
-
if self.emb_mode == "cell_and_gene":
|
681 |
gene_list = minibatch["input_ids"]
|
682 |
# need to truncate gene_list
|
|
|
|
|
|
|
683 |
gene_list = [
|
684 |
-
[g for g in genes if g not in
|
685 |
:n_perturbation_genes
|
686 |
]
|
687 |
for genes in gene_list
|
688 |
]
|
689 |
-
# remove CLS if present
|
690 |
-
if "cls" in self.emb_mode:
|
691 |
-
|
|
|
|
|
692 |
|
693 |
for cell_i, genes in enumerate(gene_list):
|
694 |
for gene_j, affected_gene in enumerate(genes):
|
@@ -760,10 +779,9 @@ class InSilicoPerturber:
|
|
760 |
del full_perturbation_emb
|
761 |
del perturbation_emb
|
762 |
del cos_sims_data
|
763 |
-
if "cls" in self.emb_mode:
|
764 |
del original_cls_emb
|
765 |
del perturbation_cls_emb
|
766 |
-
del cls_cos_sims
|
767 |
|
768 |
torch.cuda.empty_cache()
|
769 |
|
@@ -772,7 +790,7 @@ class InSilicoPerturber:
|
|
772 |
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
773 |
)
|
774 |
|
775 |
-
if self.emb_mode == "cell_and_gene":
|
776 |
pu.write_perturbation_dictionary(
|
777 |
stored_gene_embs_dict,
|
778 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
@@ -794,7 +812,7 @@ class InSilicoPerturber:
|
|
794 |
for state in pu.get_possible_states(self.cell_states_to_model)
|
795 |
}
|
796 |
|
797 |
-
if self.emb_mode == "cell_and_gene":
|
798 |
stored_gene_embs_dict = defaultdict(list)
|
799 |
for i in trange(len(filtered_input_data)):
|
800 |
example_cell = filtered_input_data.select([i])
|
@@ -840,27 +858,31 @@ class InSilicoPerturber:
|
|
840 |
)
|
841 |
|
842 |
num_inds_perturbed = 1 + self.combos
|
843 |
-
|
|
|
844 |
if "cls" not in self.emb_mode:
|
845 |
start = 0
|
|
|
846 |
else:
|
847 |
start = 1
|
|
|
848 |
if self.perturb_type == "overexpress":
|
849 |
-
perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed
|
850 |
gene_list = gene_list[
|
851 |
-
start+num_inds_perturbed:
|
852 |
-
] # cls and index 0 is not overexpressed
|
853 |
|
854 |
elif self.perturb_type == "delete":
|
855 |
-
perturbation_emb = full_perturbation_emb[:, start
|
856 |
-
gene_list = gene_list[start:]
|
857 |
|
858 |
-
full_original_emb = full_original_emb[:, start:, :]
|
859 |
original_batch = pu.make_comparison_batch(
|
860 |
full_original_emb, indices_to_perturb, perturb_group=False
|
861 |
)
|
862 |
|
863 |
-
|
|
|
|
|
864 |
gene_cos_sims = pu.quant_cos_sims(
|
865 |
perturbation_emb,
|
866 |
original_batch,
|
@@ -890,7 +912,7 @@ class InSilicoPerturber:
|
|
890 |
emb_mode="cell",
|
891 |
)
|
892 |
|
893 |
-
if self.emb_mode == "cell_and_gene":
|
894 |
# remove perturbed index for gene list
|
895 |
perturbed_gene_dict = {
|
896 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
@@ -942,19 +964,19 @@ class InSilicoPerturber:
|
|
942 |
)
|
943 |
|
944 |
# save dict to disk every 100 cells
|
945 |
-
if i %
|
946 |
pu.write_perturbation_dictionary(
|
947 |
cos_sims_dict,
|
948 |
f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
|
949 |
)
|
950 |
-
if self.emb_mode == "cell_and_gene":
|
951 |
pu.write_perturbation_dictionary(
|
952 |
stored_gene_embs_dict,
|
953 |
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
954 |
)
|
955 |
|
956 |
# reset and clear memory every 1000 cells
|
957 |
-
if i %
|
958 |
pickle_batch += 1
|
959 |
if self.cell_states_to_model is None:
|
960 |
cos_sims_dict = defaultdict(list)
|
@@ -964,7 +986,7 @@ class InSilicoPerturber:
|
|
964 |
for state in pu.get_possible_states(self.cell_states_to_model)
|
965 |
}
|
966 |
|
967 |
-
if self.emb_mode == "cell_and_gene":
|
968 |
stored_gene_embs_dict = defaultdict(list)
|
969 |
|
970 |
torch.cuda.empty_cache()
|
@@ -973,7 +995,7 @@ class InSilicoPerturber:
|
|
973 |
cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
|
974 |
)
|
975 |
|
976 |
-
if self.emb_mode == "cell_and_gene":
|
977 |
pu.write_perturbation_dictionary(
|
978 |
stored_gene_embs_dict,
|
979 |
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
|
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
token_dictionary_file=None,
|
102 |
+
clear_mem_ncells=1000,
|
103 |
):
|
104 |
"""
|
105 |
Initialize in silico perturber.
|
|
|
188 |
| Number of CPU processes to use.
|
189 |
token_dictionary_file : Path
|
190 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
191 |
+
clear_mem_ncells : int
|
192 |
+
| Clear memory every n cells.
|
193 |
"""
|
194 |
try:
|
195 |
set_start_method("spawn")
|
|
|
227 |
self.forward_batch_size = forward_batch_size
|
228 |
self.nproc = nproc
|
229 |
self.token_dictionary_file = token_dictionary_file
|
230 |
+
self.clear_mem_ncells = clear_mem_ncells
|
231 |
|
232 |
self.validate_options()
|
233 |
|
|
|
239 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
240 |
|
241 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
242 |
+
self.cls_token_id = self.gene_token_dict.get("<cls>")
|
243 |
+
self.eos_token_id = self.gene_token_dict.get("<eos>")
|
244 |
|
245 |
|
246 |
# Identify if special token is present in the token dictionary
|
247 |
+
if (self.cls_token_id is not None) and (self.eos_token_id is not None):
|
|
|
|
|
|
|
248 |
self.special_token = True
|
249 |
else:
|
250 |
if "cls" in self.emb_mode:
|
251 |
+
logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
|
252 |
raise
|
253 |
self.special_token = False
|
254 |
|
|
|
457 |
|
458 |
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
459 |
if self.special_token:
|
460 |
+
if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
|
|
|
461 |
logger.error(
|
462 |
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
463 |
)
|
464 |
raise
|
465 |
+
if ("cls" in self.emb_mode):
|
466 |
+
if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
|
467 |
+
logger.error(
|
468 |
+
"Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
|
469 |
+
)
|
470 |
+
raise
|
471 |
|
472 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
473 |
|
|
|
562 |
perturbed_data = filtered_input_data.map(
|
563 |
make_group_perturbation_batch, num_proc=self.nproc
|
564 |
)
|
565 |
+
|
566 |
if self.perturb_type == "overexpress":
|
567 |
filtered_input_data = filtered_input_data.add_column(
|
568 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
581 |
pu.truncate_by_n_overflow, num_proc=self.nproc
|
582 |
)
|
583 |
|
584 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
585 |
stored_gene_embs_dict = defaultdict(list)
|
586 |
|
587 |
# iterate through batches
|
|
|
627 |
|
628 |
if "cls" not in self.emb_mode:
|
629 |
start = 0
|
630 |
+
end_add = 0
|
631 |
+
end = None
|
632 |
else:
|
633 |
start = 1
|
634 |
+
end_add = -1
|
635 |
+
end = -1
|
636 |
|
637 |
+
# remove overexpressed genes and cls/eos
|
638 |
original_emb = original_emb[
|
639 |
+
:, start : end, :
|
640 |
]
|
641 |
if self.perturb_type == "overexpress":
|
642 |
perturbation_emb = full_perturbation_emb[
|
643 |
+
:, start+len(self.tokens_to_perturb) : end, :
|
644 |
]
|
645 |
elif self.perturb_type == "delete":
|
646 |
perturbation_emb = full_perturbation_emb[
|
647 |
+
:, start : max(perturbation_batch["length"])+end_add, :
|
648 |
]
|
649 |
|
650 |
n_perturbation_genes = perturbation_emb.size()[1]
|
|
|
653 |
if (
|
654 |
self.cell_states_to_model is None
|
655 |
or self.emb_mode == "cell_and_gene"
|
656 |
+
or self.emb_mode == "cls_and_gene"
|
657 |
):
|
658 |
gene_cos_sims = pu.quant_cos_sims(
|
659 |
perturbation_emb,
|
|
|
691 |
|
692 |
# get cosine similarities in gene embeddings
|
693 |
# if getting gene embeddings, need gene names
|
694 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
695 |
gene_list = minibatch["input_ids"]
|
696 |
# need to truncate gene_list
|
697 |
+
genes_to_exclude = self.tokens_to_perturb
|
698 |
+
if self.emb_mode == "cls_and_gene":
|
699 |
+
genes_to_exclude = genes_to_exclude + [self.cls_token_id, self.eos_token_id]
|
700 |
gene_list = [
|
701 |
+
[g for g in genes if g not in genes_to_exclude][
|
702 |
:n_perturbation_genes
|
703 |
]
|
704 |
for genes in gene_list
|
705 |
]
|
706 |
+
# remove CLS and EOS if present
|
707 |
+
# if "cls" in self.emb_mode:
|
708 |
+
# cls_token_id = self.gene_token_dict["<cls>"]
|
709 |
+
# eos_token_id = self.gene_token_dict["<eos>"]
|
710 |
+
# gene_list = [e for e in gene_list if e not in [cls_token_id,eos_token_id]]
|
711 |
|
712 |
for cell_i, genes in enumerate(gene_list):
|
713 |
for gene_j, affected_gene in enumerate(genes):
|
|
|
779 |
del full_perturbation_emb
|
780 |
del perturbation_emb
|
781 |
del cos_sims_data
|
782 |
+
if ("cls" in self.emb_mode) and (self.cell_states_to_model is None):
|
783 |
del original_cls_emb
|
784 |
del perturbation_cls_emb
|
|
|
785 |
|
786 |
torch.cuda.empty_cache()
|
787 |
|
|
|
790 |
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
791 |
)
|
792 |
|
793 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
794 |
pu.write_perturbation_dictionary(
|
795 |
stored_gene_embs_dict,
|
796 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
|
|
812 |
for state in pu.get_possible_states(self.cell_states_to_model)
|
813 |
}
|
814 |
|
815 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
816 |
stored_gene_embs_dict = defaultdict(list)
|
817 |
for i in trange(len(filtered_input_data)):
|
818 |
example_cell = filtered_input_data.select([i])
|
|
|
858 |
)
|
859 |
|
860 |
num_inds_perturbed = 1 + self.combos
|
861 |
+
|
862 |
+
# need to remove overexpressed gene and cls/eos to quantify cosine shifts
|
863 |
if "cls" not in self.emb_mode:
|
864 |
start = 0
|
865 |
+
end = None
|
866 |
else:
|
867 |
start = 1
|
868 |
+
end = -1
|
869 |
if self.perturb_type == "overexpress":
|
870 |
+
perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:end, :]
|
871 |
gene_list = gene_list[
|
872 |
+
start+num_inds_perturbed:end
|
873 |
+
] # cls/eos and index 0 is not overexpressed
|
874 |
|
875 |
elif self.perturb_type == "delete":
|
876 |
+
perturbation_emb = full_perturbation_emb[:, start:end, :]
|
877 |
+
gene_list = gene_list[start:end]
|
878 |
|
|
|
879 |
original_batch = pu.make_comparison_batch(
|
880 |
full_original_emb, indices_to_perturb, perturb_group=False
|
881 |
)
|
882 |
|
883 |
+
original_batch = original_batch[:, start:end, :]
|
884 |
+
|
885 |
+
if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene" or self.emb_mode == "cls_and_gene":
|
886 |
gene_cos_sims = pu.quant_cos_sims(
|
887 |
perturbation_emb,
|
888 |
original_batch,
|
|
|
912 |
emb_mode="cell",
|
913 |
)
|
914 |
|
915 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
916 |
# remove perturbed index for gene list
|
917 |
perturbed_gene_dict = {
|
918 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
|
|
964 |
)
|
965 |
|
966 |
# save dict to disk every 100 cells
|
967 |
+
if i % clear_mem_ncells/10 == 0:
|
968 |
pu.write_perturbation_dictionary(
|
969 |
cos_sims_dict,
|
970 |
f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
|
971 |
)
|
972 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
973 |
pu.write_perturbation_dictionary(
|
974 |
stored_gene_embs_dict,
|
975 |
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|
976 |
)
|
977 |
|
978 |
# reset and clear memory every 1000 cells
|
979 |
+
if i % clear_mem_ncells == 0:
|
980 |
pickle_batch += 1
|
981 |
if self.cell_states_to_model is None:
|
982 |
cos_sims_dict = defaultdict(list)
|
|
|
986 |
for state in pu.get_possible_states(self.cell_states_to_model)
|
987 |
}
|
988 |
|
989 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
990 |
stored_gene_embs_dict = defaultdict(list)
|
991 |
|
992 |
torch.cuda.empty_cache()
|
|
|
995 |
cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
|
996 |
)
|
997 |
|
998 |
+
if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
|
999 |
pu.write_perturbation_dictionary(
|
1000 |
stored_gene_embs_dict,
|
1001 |
f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
|