fix overexpression with anchor genes and add perturbation batch with special tokens
geneformer/perturber_utils.py  CHANGED  (+90 −16)
@@ -155,6 +155,9 @@ def quant_layers(model):
             layer_nums += [int(name.split("layer.")[1].split(".")[0])]
     return int(max(layer_nums)) + 1
 
+def get_model_embedding_dimensions(model):
+    return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
+
 
 def get_model_input_size(model):
     return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
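The new get_model_embedding_dimensions helper reads the embedding width out of the string representation of the position-embedding layer, the same trick get_model_input_size already uses for the input size. A minimal sanity-check sketch, assuming a BERT-style model loaded with Hugging Face transformers; the toy config below is illustrative, not Geneformer's actual hyperparameters:

import re
from transformers import BertConfig, BertForMaskedLM

# Toy config for illustration only; the position-embedding repr will be "Embedding(512, 256)".
config = BertConfig(
    vocab_size=100,
    hidden_size=256,
    max_position_embeddings=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=512,
)
model = BertForMaskedLM(config)

rep = str(model.bert.embeddings.position_embeddings)  # "Embedding(512, 256)"
input_size = int(re.split(r"\(|,", rep)[1])  # 512, what get_model_input_size returns
embed_dim = int(re.split(r"\(|,", rep)[2].strip().replace(")", ""))  # 256, what get_model_embedding_dimensions returns
print(input_size, embed_dim)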
@@ -222,9 +225,10 @@ def overexpress_indices(example):
     indices = example["perturb_index"]
     if any(isinstance(el, list) for el in indices):
         indices = flatten_list(indices)
-
-
-
+    insert_pos = 0
+    for index in sorted(indices, reverse=False):
+        example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
+        insert_pos += 1
     example["length"] = len(example["input_ids"])
     return example
 
@@ -233,15 +237,15 @@ def overexpress_indices_special(example):
     indices = example["perturb_index"]
     if any(isinstance(el, list) for el in indices):
         indices = flatten_list(indices)
-
-
-
+    insert_pos = 1  # Insert starting after CLS token
+    for index in sorted(indices, reverse=False):
+        example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
+        insert_pos += 1
     example["length"] = len(example["input_ids"])
     return example
 
 # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
 def overexpress_tokens(example, max_len, special_token):
-    original_len = example["length"]
     # -100 indicates tokens to overexpress are not present in rank value encoding
     if example["perturb_index"] != [-100]:
         example = delete_indices(example)
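Both overexpression helpers now walk the perturb indices in ascending order and move each corresponding token to the front of the rank value encoding, preserving the tokens' relative order; the _special variant starts inserting at position 1 so the CLS token stays in place, and the EOS token, never being a perturb index, stays at the end. A rough walk-through of that insertion logic on made-up token ids:

# Toy cell: 101 = CLS, 102 = EOS (hypothetical ids); perturbing the tokens at positions 3 and 4.
example = {"input_ids": [101, 10, 11, 12, 13, 102], "perturb_index": [[3], [4]]}

indices = [i for sub in example["perturb_index"] for i in sub]  # what flatten_list produces
insert_pos = 1  # 1 in overexpress_indices_special (after CLS); 0 in overexpress_indices
for index in sorted(indices, reverse=False):
    example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
    insert_pos += 1
example["length"] = len(example["input_ids"])

print(example["input_ids"])  # [101, 12, 13, 10, 11, 102]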
@@ -347,7 +351,7 @@ def remove_perturbed_indices_set(
 
 
 def make_perturbation_batch(
-    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
+    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
@@ -355,7 +359,7 @@ def make_perturbation_batch(
         elif perturb_type in ["delete", "inhibit"]:
             range_start = 0
         indices_to_perturb = [
-
+            [i] for i in range(range_start, example_cell["length"][0])
         ]
     # elif combo_lvl > 0 and anchor_token is None:
     ## to implement
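In the non-special batch builder, perturbing "all" genes at combo level 0 simply enumerates single-position lists over the whole encoding, starting at 1 for overexpression (moving the already top-ranked gene to the front would be a no-op) and at 0 for deletion. A quick illustration with a made-up cell length:

# Hypothetical cell of length 5, perturb_type = "overexpress" -> range_start = 1.
example_cell = {"length": [5]}
range_start = 1
indices_to_perturb = [
    [i] for i in range(range_start, example_cell["length"][0])
]
print(indices_to_perturb)  # [[1], [2], [3], [4]]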
@@ -409,14 +413,84 @@ def make_perturbation_batch(
             delete_indices, num_proc=num_proc_i
         )
     elif perturb_type == "overexpress":
-
-            perturbation_dataset = perturbation_dataset.map(
-                overexpress_indices_special, num_proc=num_proc_i
-            )
-        else:
-            perturbation_dataset = perturbation_dataset.map(
+        perturbation_dataset = perturbation_dataset.map(
             overexpress_indices, num_proc=num_proc_i
-
+        )
+
+    perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
+
+    return perturbation_dataset, indices_to_perturb
+
+
+def make_perturbation_batch_special(
+    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
+) -> tuple[Dataset, List[int]]:
+    if combo_lvl == 0 and tokens_to_perturb == "all":
+        if perturb_type in ["overexpress", "activate"]:
+            range_start = 1
+        elif perturb_type in ["delete", "inhibit"]:
+            range_start = 0
+        range_start += 1  # Starting after the CLS token
+        indices_to_perturb = [
+            [i] for i in range(range_start, example_cell["length"][0] - 1)  # And excluding the EOS token
+        ]
+
+    # elif combo_lvl > 0 and anchor_token is None:
+    ## to implement
+    elif combo_lvl > 0 and (anchor_token is not None):
+        example_input_ids = example_cell["input_ids"][0]
+        anchor_index = example_input_ids.index(anchor_token[0])
+        indices_to_perturb = [
+            sorted([anchor_index, i]) if i != anchor_index else None
+            for i in range(1, example_cell["length"][0] - 1)  # Exclude CLS and EOS tokens
+        ]
+        indices_to_perturb = [item for item in indices_to_perturb if item is not None]
+    else:  # still need to update
+        example_input_ids = example_cell["input_ids"][0]
+        indices_to_perturb = [
+            [example_input_ids.index(token)] if token in example_input_ids else None
+            for token in tokens_to_perturb
+        ]
+        indices_to_perturb = [item for item in indices_to_perturb if item is not None]
+
+    # create all permutations of combo_lvl of modifiers from tokens_to_perturb
+    # still need to update
+    if combo_lvl > 0 and (anchor_token is None):
+        if tokens_to_perturb != "all":
+            if len(tokens_to_perturb) == combo_lvl + 1:
+                indices_to_perturb = [
+                    list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
+                ]
+            else:
+                all_indices = [[i] for i in range(1, example_cell["length"][0] - 1)]  # Exclude CLS and EOS tokens
+                all_indices = [
+                    index for index in all_indices if index not in indices_to_perturb
+                ]
+                indices_to_perturb = [
+                    [[j for i in indices_to_perturb for j in i], x] for x in all_indices
+                ]
+
+    length = len(indices_to_perturb)
+    perturbation_dataset = Dataset.from_dict(
+        {
+            "input_ids": example_cell["input_ids"] * length,
+            "perturb_index": indices_to_perturb,
+        }
+    )
+
+    if length < 400:
+        num_proc_i = 1
+    else:
+        num_proc_i = num_proc
+
+    if perturb_type == "delete":
+        perturbation_dataset = perturbation_dataset.map(
+            delete_indices, num_proc=num_proc_i
+        )
+    elif perturb_type == "overexpress":
+        perturbation_dataset = perturbation_dataset.map(
+            overexpress_indices_special, num_proc=num_proc_i
+        )
 
     perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
 
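For combination perturbations with an anchor gene, the new make_perturbation_batch_special pairs the anchor's position with every other position between the CLS and EOS tokens, then builds one dataset row per pair. A small sketch of just that pairing step, on a made-up cell (token ids and length are illustrative):

# Hypothetical rank value encoding: position 0 = CLS, last position = EOS.
example_cell = {"input_ids": [[0, 201, 202, 203, 204, 1]], "length": [6]}
anchor_token = [203]

example_input_ids = example_cell["input_ids"][0]
anchor_index = example_input_ids.index(anchor_token[0])  # 3
indices_to_perturb = [
    sorted([anchor_index, i]) if i != anchor_index else None
    for i in range(1, example_cell["length"][0] - 1)  # skip CLS and EOS
]
indices_to_perturb = [pair for pair in indices_to_perturb if pair is not None]
print(indices_to_perturb)  # [[1, 3], [2, 3], [3, 4]]

Each pair then becomes one "perturb_index" entry in the Dataset that delete_indices or overexpress_indices_special consumes downstream.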