Upload in_silico_perturber.py
#432
by
davidjwen
- opened
geneformer/in_silico_perturber.py
CHANGED
|
@@ -40,7 +40,7 @@ import pickle
|
|
| 40 |
from collections import defaultdict
|
| 41 |
|
| 42 |
import torch
|
| 43 |
-
from datasets import Dataset
|
| 44 |
from multiprocess import set_start_method
|
| 45 |
from tqdm.auto import trange
|
| 46 |
|
|
@@ -48,7 +48,9 @@ from . import TOKEN_DICTIONARY_FILE
|
|
| 48 |
from . import perturber_utils as pu
|
| 49 |
from .emb_extractor import get_embs
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
logger = logging.getLogger(__name__)
|
| 54 |
|
|
@@ -794,6 +796,8 @@ class InSilicoPerturber:
|
|
| 794 |
return example
|
| 795 |
|
| 796 |
total_batch_length = len(filtered_input_data)
|
|
|
|
|
|
|
| 797 |
if self.cell_states_to_model is None:
|
| 798 |
cos_sims_dict = defaultdict(list)
|
| 799 |
else:
|
|
@@ -878,7 +882,7 @@ class InSilicoPerturber:
|
|
| 878 |
)
|
| 879 |
|
| 880 |
##### CLS and Gene Embedding Mode #####
|
| 881 |
-
elif self.emb_mode == "cls_and_gene":
|
| 882 |
full_original_emb = get_embs(
|
| 883 |
model,
|
| 884 |
minibatch,
|
|
@@ -891,6 +895,7 @@ class InSilicoPerturber:
|
|
| 891 |
silent=True,
|
| 892 |
)
|
| 893 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
|
|
|
| 894 |
# remove indices that were perturbed
|
| 895 |
original_emb = pu.remove_perturbed_indices_set(
|
| 896 |
full_original_emb,
|
|
@@ -899,6 +904,7 @@ class InSilicoPerturber:
|
|
| 899 |
self.tokens_to_perturb,
|
| 900 |
minibatch["length"],
|
| 901 |
)
|
|
|
|
| 902 |
full_perturbation_emb = get_embs(
|
| 903 |
model,
|
| 904 |
perturbation_batch,
|
|
@@ -910,7 +916,7 @@ class InSilicoPerturber:
|
|
| 910 |
summary_stat=None,
|
| 911 |
silent=True,
|
| 912 |
)
|
| 913 |
-
|
| 914 |
# remove special tokens and padding
|
| 915 |
original_emb = original_emb[:, 1:-1, :]
|
| 916 |
if self.perturb_type == "overexpress":
|
|
@@ -921,9 +927,25 @@ class InSilicoPerturber:
|
|
| 921 |
perturbation_emb = full_perturbation_emb[
|
| 922 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
| 923 |
]
|
| 924 |
-
|
| 925 |
n_perturbation_genes = perturbation_emb.size()[1]
|
| 926 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
gene_cos_sims = pu.quant_cos_sims(
|
| 928 |
perturbation_emb,
|
| 929 |
original_emb,
|
|
|
|
| 40 |
from collections import defaultdict
|
| 41 |
|
| 42 |
import torch
|
| 43 |
+
from datasets import Dataset
|
| 44 |
from multiprocess import set_start_method
|
| 45 |
from tqdm.auto import trange
|
| 46 |
|
|
|
|
| 48 |
from . import perturber_utils as pu
|
| 49 |
from .emb_extractor import get_embs
|
| 50 |
|
| 51 |
+
import datasets
|
| 52 |
+
datasets.logging.disable_progress_bar()
|
| 53 |
+
|
| 54 |
|
| 55 |
logger = logging.getLogger(__name__)
|
| 56 |
|
|
|
|
| 796 |
return example
|
| 797 |
|
| 798 |
total_batch_length = len(filtered_input_data)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
if self.cell_states_to_model is None:
|
| 802 |
cos_sims_dict = defaultdict(list)
|
| 803 |
else:
|
|
|
|
| 882 |
)
|
| 883 |
|
| 884 |
##### CLS and Gene Embedding Mode #####
|
| 885 |
+
elif self.emb_mode == "cls_and_gene":
|
| 886 |
full_original_emb = get_embs(
|
| 887 |
model,
|
| 888 |
minibatch,
|
|
|
|
| 895 |
silent=True,
|
| 896 |
)
|
| 897 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
| 898 |
+
|
| 899 |
# remove indices that were perturbed
|
| 900 |
original_emb = pu.remove_perturbed_indices_set(
|
| 901 |
full_original_emb,
|
|
|
|
| 904 |
self.tokens_to_perturb,
|
| 905 |
minibatch["length"],
|
| 906 |
)
|
| 907 |
+
|
| 908 |
full_perturbation_emb = get_embs(
|
| 909 |
model,
|
| 910 |
perturbation_batch,
|
|
|
|
| 916 |
summary_stat=None,
|
| 917 |
silent=True,
|
| 918 |
)
|
| 919 |
+
|
| 920 |
# remove special tokens and padding
|
| 921 |
original_emb = original_emb[:, 1:-1, :]
|
| 922 |
if self.perturb_type == "overexpress":
|
|
|
|
| 927 |
perturbation_emb = full_perturbation_emb[
|
| 928 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
| 929 |
]
|
| 930 |
+
|
| 931 |
n_perturbation_genes = perturbation_emb.size()[1]
|
| 932 |
|
| 933 |
+
# truncate the original embedding when it is longer (along dim 1) than the perturbation embedding
|
| 934 |
+
if self.perturb_type == "overexpress":
|
| 935 |
+
def calc_perturbation_length(ids):
|
| 936 |
+
if ids == [-100]:
|
| 937 |
+
return 0
|
| 938 |
+
else:
|
| 939 |
+
return len(ids)
|
| 940 |
+
|
| 941 |
+
max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
|
| 942 |
+
|
| 943 |
+
max_n_overflow = max(minibatch["n_overflow"])
|
| 944 |
+
if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
|
| 945 |
+
original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
|
| 946 |
+
elif perturbation_emb.size()[1] < original_emb.size()[1]:
|
| 947 |
+
original_emb = original_emb[:, 0:max_tensor_size, :]
|
| 948 |
+
|
| 949 |
gene_cos_sims = pu.quant_cos_sims(
|
| 950 |
perturbation_emb,
|
| 951 |
original_emb,
|