hchen725 committed on
Commit
9154763
1 Parent(s): bcaf65e

fix overexpression with anchor genes and add perturbation batch with special tokens

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +90 -16
geneformer/perturber_utils.py CHANGED
@@ -155,6 +155,9 @@ def quant_layers(model):
155
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
156
  return int(max(layer_nums)) + 1
157
 
 
 
 
158
 
159
  def get_model_input_size(model):
160
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
@@ -222,9 +225,10 @@ def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
- for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
227
-
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
@@ -233,15 +237,15 @@ def overexpress_indices_special(example):
233
  indices = example["perturb_index"]
234
  if any(isinstance(el, list) for el in indices):
235
  indices = flatten_list(indices)
236
- for index in sorted(indices, reverse=True):
237
- example["input_ids"].insert(1, example["input_ids"].pop(index))
238
-
 
239
  example["length"] = len(example["input_ids"])
240
  return example
241
 
242
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
243
  def overexpress_tokens(example, max_len, special_token):
244
- original_len = example["length"]
245
  # -100 indicates tokens to overexpress are not present in rank value encoding
246
  if example["perturb_index"] != [-100]:
247
  example = delete_indices(example)
@@ -347,7 +351,7 @@ def remove_perturbed_indices_set(
347
 
348
 
349
  def make_perturbation_batch(
350
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
351
  ) -> tuple[Dataset, List[int]]:
352
  if combo_lvl == 0 and tokens_to_perturb == "all":
353
  if perturb_type in ["overexpress", "activate"]:
@@ -355,7 +359,7 @@ def make_perturbation_batch(
355
  elif perturb_type in ["delete", "inhibit"]:
356
  range_start = 0
357
  indices_to_perturb = [
358
- [i] for i in range(range_start, example_cell["length"][0])
359
  ]
360
  # elif combo_lvl > 0 and anchor_token is None:
361
  ## to implement
@@ -409,14 +413,84 @@ def make_perturbation_batch(
409
  delete_indices, num_proc=num_proc_i
410
  )
411
  elif perturb_type == "overexpress":
412
- if special_token:
413
- perturbation_dataset = perturbation_dataset.map(
414
- overexpress_indices_special, num_proc=num_proc_i
415
- )
416
- else:
417
- perturbation_dataset = perturbation_dataset.map(
418
  overexpress_indices, num_proc=num_proc_i
419
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
422
 
 
155
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
156
  return int(max(layer_nums)) + 1
157
 
158
+ def get_model_embedding_dimensions(model):
159
+ return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
160
+
161
 
162
  def get_model_input_size(model):
163
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
 
225
  indices = example["perturb_index"]
226
  if any(isinstance(el, list) for el in indices):
227
  indices = flatten_list(indices)
228
+ insert_pos = 0
229
+ for index in sorted(indices, reverse=False):
230
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
231
+ insert_pos += 1
232
  example["length"] = len(example["input_ids"])
233
  return example
234
 
 
237
  indices = example["perturb_index"]
238
  if any(isinstance(el, list) for el in indices):
239
  indices = flatten_list(indices)
240
+ insert_pos = 1 # Insert starting after CLS token
241
+ for index in sorted(indices, reverse=False):
242
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
243
+ insert_pos += 1
244
  example["length"] = len(example["input_ids"])
245
  return example
246
 
247
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
248
  def overexpress_tokens(example, max_len, special_token):
 
249
  # -100 indicates tokens to overexpress are not present in rank value encoding
250
  if example["perturb_index"] != [-100]:
251
  example = delete_indices(example)
 
351
 
352
 
353
  def make_perturbation_batch(
354
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
355
  ) -> tuple[Dataset, List[int]]:
356
  if combo_lvl == 0 and tokens_to_perturb == "all":
357
  if perturb_type in ["overexpress", "activate"]:
 
359
  elif perturb_type in ["delete", "inhibit"]:
360
  range_start = 0
361
  indices_to_perturb = [
362
+ [i] for i in range(range_start, example_cell["length"][0])
363
  ]
364
  # elif combo_lvl > 0 and anchor_token is None:
365
  ## to implement
 
413
  delete_indices, num_proc=num_proc_i
414
  )
415
  elif perturb_type == "overexpress":
416
+ perturbation_dataset = perturbation_dataset.map(
 
 
 
 
 
417
  overexpress_indices, num_proc=num_proc_i
418
+ )
419
+
420
+ perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
421
+
422
+ return perturbation_dataset, indices_to_perturb
423
+
424
+
425
+ def make_perturbation_batch_special(
426
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
427
+ ) -> tuple[Dataset, List[int]]:
428
+ if combo_lvl == 0 and tokens_to_perturb == "all":
429
+ if perturb_type in ["overexpress", "activate"]:
430
+ range_start = 1
431
+ elif perturb_type in ["delete", "inhibit"]:
432
+ range_start = 0
433
+ range_start += 1 # Starting after the CLS token
434
+ indices_to_perturb = [
435
+ [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token
436
+ ]
437
+
438
+ # elif combo_lvl > 0 and anchor_token is None:
439
+ ## to implement
440
+ elif combo_lvl > 0 and (anchor_token is not None):
441
+ example_input_ids = example_cell["input_ids"][0]
442
+ anchor_index = example_input_ids.index(anchor_token[0])
443
+ indices_to_perturb = [
444
+ sorted([anchor_index, i]) if i != anchor_index else None
445
+ for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens
446
+ ]
447
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
448
+ else: # still need to update
449
+ example_input_ids = example_cell["input_ids"][0]
450
+ indices_to_perturb = [
451
+ [example_input_ids.index(token)] if token in example_input_ids else None
452
+ for token in tokens_to_perturb
453
+ ]
454
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
455
+
456
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
457
+ # still need to update
458
+ if combo_lvl > 0 and (anchor_token is None):
459
+ if tokens_to_perturb != "all":
460
+ if len(tokens_to_perturb) == combo_lvl + 1:
461
+ indices_to_perturb = [
462
+ list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
463
+ ]
464
+ else:
465
+ all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens
466
+ all_indices = [
467
+ index for index in all_indices if index not in indices_to_perturb
468
+ ]
469
+ indices_to_perturb = [
470
+ [[j for i in indices_to_perturb for j in i], x] for x in all_indices
471
+ ]
472
+
473
+ length = len(indices_to_perturb)
474
+ perturbation_dataset = Dataset.from_dict(
475
+ {
476
+ "input_ids": example_cell["input_ids"] * length,
477
+ "perturb_index": indices_to_perturb,
478
+ }
479
+ )
480
+
481
+ if length < 400:
482
+ num_proc_i = 1
483
+ else:
484
+ num_proc_i = num_proc
485
+
486
+ if perturb_type == "delete":
487
+ perturbation_dataset = perturbation_dataset.map(
488
+ delete_indices, num_proc=num_proc_i
489
+ )
490
+ elif perturb_type == "overexpress":
491
+ perturbation_dataset = perturbation_dataset.map(
492
+ overexpress_indices_special, num_proc=num_proc_i
493
+ )
494
 
495
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
496