davidjwen commited on
Commit
771c8bd
1 Parent(s): 65b4915

Fixed error with perturbing individual genes and updated ways to specify cell_states_to_model

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +138 -34
geneformer/in_silico_perturber.py CHANGED
@@ -105,6 +105,12 @@ def downsample_and_sort(data_shuffled, max_ncells):
105
  data_sorted = data_subset.sort("length",reverse=True)
106
  return data_sorted
107
 
 
 
 
 
 
 
108
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
109
  example_cell.set_format(type="torch")
110
  input_data = example_cell["input_ids"]
@@ -235,13 +241,15 @@ def get_cell_state_avg_embs(model,
235
  num_proc):
236
 
237
  model_input_size = get_model_input_size(model)
238
- possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
239
  state_embs_dict = dict()
240
  for possible_state in possible_states:
241
  state_embs_list = []
 
242
 
243
  def filter_states(example):
244
- return example[list(cell_states_to_model.keys())[0]] in [possible_state]
 
245
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
246
  total_batch_length = len(filtered_input_data_state)
247
  if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -254,6 +262,7 @@ def get_cell_state_avg_embs(model,
254
  state_minibatch.set_format(type="torch")
255
 
256
  input_data_minibatch = state_minibatch["input_ids"]
 
257
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
258
  max_len,
259
  pad_token_id,
@@ -271,8 +280,12 @@ def get_cell_state_avg_embs(model,
271
  del input_data_minibatch
272
  del state_embs_i
273
  torch.cuda.empty_cache()
274
- state_embs_stack = torch.cat(state_embs_list)
275
- avg_state_emb = torch.mean(state_embs_stack,dim=[0,1],keepdim=True)
 
 
 
 
276
  state_embs_dict[possible_state] = avg_state_emb
277
  return state_embs_dict
278
 
@@ -291,7 +304,6 @@ def quant_cos_sims(model,
291
  pad_token_id,
292
  model_input_size,
293
  nproc):
294
-
295
  cos = torch.nn.CosineSimilarity(dim=2)
296
  total_batch_length = len(perturbation_batch)
297
  if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -301,7 +313,7 @@ def quant_cos_sims(model,
301
  comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
302
  cos_sims = []
303
  else:
304
- possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
305
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
306
 
307
  # measure length of each element in perturbation_batch
@@ -316,6 +328,7 @@ def quant_cos_sims(model,
316
 
317
  # determine if need to pad or truncate batch
318
  minibatch_length_set = set(perturbation_minibatch["length"])
 
319
  if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
320
  needs_pad_or_trunc = True
321
  else:
@@ -360,6 +373,7 @@ def quant_cos_sims(model,
360
  # truncate to the (model input size - # tokens to overexpress) to ensure comparability
361
  # since max input size of perturb batch will be reduced by # tokens to overexpress
362
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
 
363
  original_minibatch_length_set = set(original_minibatch["length"])
364
  if perturb_type == "overexpress":
365
  new_max_len = model_input_size - len(tokens_to_perturb)
@@ -385,7 +399,32 @@ def quant_cos_sims(model,
385
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
386
  else:
387
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
388
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  # cosine similarity between original emb and batch items
390
  if cell_states_to_model is None:
391
  if perturb_group == False:
@@ -406,7 +445,9 @@ def quant_cos_sims(model,
406
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
407
  minibatch_emb,
408
  state_embs_dict[state],
409
- perturb_group)
 
 
410
  del outputs
411
  del minibatch_emb
412
  if cell_states_to_model is None:
@@ -421,14 +462,40 @@ def quant_cos_sims(model,
421
  return cos_sims_vs_alt_dict
422
 
423
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
424
- def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group):
425
  cos = torch.nn.CosineSimilarity(dim=2)
426
- original_emb = torch.mean(original_emb,dim=0,keepdim=True)
427
- if perturb_group == False:
428
  original_emb = original_emb[None, :]
429
- origin_v_end = cos(original_emb,alt_emb)
430
- perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
431
- perturb_v_end = cos(perturb_emb,alt_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  return [(perturb_v_end-origin_v_end).to("cpu")]
433
 
434
  def pad_list(input_ids, pad_token_id, max_len):
@@ -706,6 +773,12 @@ class InSilicoPerturber:
706
 
707
  if self.cell_states_to_model is not None:
708
  if len(self.cell_states_to_model.items()) == 1:
 
 
 
 
 
 
709
  for key,value in self.cell_states_to_model.items():
710
  if (len(value) == 3) and isinstance(value, tuple):
711
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
@@ -713,14 +786,48 @@ class InSilicoPerturber:
713
  all_values = value[0]+value[1]+value[2]
714
  if len(all_values) == len(set(all_values)):
715
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
  else:
717
  logger.error(
718
- "Cell states to model must be a single-item dictionary with " \
719
- "key being cell attribute (e.g. 'disease') and value being " \
720
- "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
721
- "Values should all be unique. " \
722
- "For example: {'disease':(['dcm'],['ctrl'],['hcm'])}")
 
 
 
723
  raise
 
724
  if self.anchor_gene is not None:
725
  self.anchor_gene = None
726
  logger.warning(
@@ -770,6 +877,13 @@ class InSilicoPerturber:
770
  if self.cell_states_to_model is None:
771
  state_embs_dict = None
772
  else:
 
 
 
 
 
 
 
773
  # get dictionary of average cell state embeddings for comparison
774
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
775
  state_embs_dict = get_cell_state_avg_embs(model,
@@ -780,9 +894,9 @@ class InSilicoPerturber:
780
  self.forward_batch_size,
781
  self.nproc)
782
  # filter for start state cells
783
- start_state = list(self.cell_states_to_model.values())[0][0][0]
784
  def filter_for_origin(example):
785
- return example[list(self.cell_states_to_model.keys())[0]] in [start_state]
786
 
787
  filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
788
 
@@ -878,7 +992,6 @@ class InSilicoPerturber:
878
  # or (perturbed_genes, "cell_emb") for avg cell emb change
879
  cos_sims_data = cos_sims_data.to("cuda")
880
  max_padded_len = cos_sims_data.shape[1]
881
-
882
  for j in range(cos_sims_data.shape[0]):
883
  # remove padding before mean pooling cell embedding
884
  original_length = original_lengths[j]
@@ -900,21 +1013,13 @@ class InSilicoPerturber:
900
  # update cos sims dict
901
  # key is tuple of (perturbed_genes, "cell_emb")
902
  # value is list of tuples of cos sims for cell_states_to_model
903
- origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
904
  cos_sims_origin = cos_sims_data[origin_state_key]
905
  for j in range(cos_sims_origin.shape[0]):
906
- original_length = original_lengths[j]
907
- max_padded_len = cos_sims_origin.shape[1]
908
- indices_removed = indices_to_perturb[j]
909
- padding_to_remove = max_padded_len - (original_length \
910
- - len(self.tokens_to_perturb) \
911
- - len(indices_removed))
912
  data_list = []
913
  for data in list(cos_sims_data.values()):
914
  data_item = data.to("cuda")
915
- nonpadding_data_item = data_item[j][:-padding_to_remove]
916
- cell_data = torch.mean(nonpadding_data_item).item()
917
- data_list += [cell_data]
918
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
919
 
920
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
@@ -987,7 +1092,7 @@ class InSilicoPerturber:
987
  # update cos sims dict
988
  # key is tuple of (perturbed_gene, "cell_emb")
989
  # value is list of tuples of cos sims for cell_states_to_model
990
- origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
991
  cos_sims_origin = cos_sims_data[origin_state_key]
992
 
993
  for j in range(cos_sims_origin.shape[0]):
@@ -1109,4 +1214,3 @@ class InSilicoPerturber:
1109
  # save remainder cells
1110
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1111
  pickle.dump(cos_sims_dict, fp)
1112
-
 
105
  data_sorted = data_subset.sort("length",reverse=True)
106
  return data_sorted
107
 
108
+ def get_possible_states(cell_states_to_model):
109
+ if list(cell_states_to_model.values())[3] is not None:
110
+ return list(cell_states_to_model.values())[1:3] + list(cell_states_to_model.values())[3]
111
+ else:
112
+ return list(cell_states_to_model.values())[1:3]
113
+
114
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
115
  example_cell.set_format(type="torch")
116
  input_data = example_cell["input_ids"]
 
241
  num_proc):
242
 
243
  model_input_size = get_model_input_size(model)
244
+ possible_states = get_possible_states(cell_states_to_model)
245
  state_embs_dict = dict()
246
  for possible_state in possible_states:
247
  state_embs_list = []
248
+ original_lens = []
249
 
250
  def filter_states(example):
251
+ state_key = cell_states_to_model["state_key"]
252
+ return example[state_key] in possible_state
253
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
254
  total_batch_length = len(filtered_input_data_state)
255
  if ((total_batch_length-1)/forward_batch_size).is_integer():
 
262
  state_minibatch.set_format(type="torch")
263
 
264
  input_data_minibatch = state_minibatch["input_ids"]
265
+ original_lens += [tensor.numel() for tensor in input_data_minibatch]
266
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
267
  max_len,
268
  pad_token_id,
 
280
  del input_data_minibatch
281
  del state_embs_i
282
  torch.cuda.empty_cache()
283
+
284
+ # import here to avoid circular imports
285
+ from .emb_extractor import mean_nonpadding_embs
286
+ state_embs = torch.cat(state_embs_list)
287
+ avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
288
+ avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
289
  state_embs_dict[possible_state] = avg_state_emb
290
  return state_embs_dict
291
 
 
304
  pad_token_id,
305
  model_input_size,
306
  nproc):
 
307
  cos = torch.nn.CosineSimilarity(dim=2)
308
  total_batch_length = len(perturbation_batch)
309
  if ((total_batch_length-1)/forward_batch_size).is_integer():
 
313
  comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
314
  cos_sims = []
315
  else:
316
+ possible_states = get_possible_states(cell_states_to_model)
317
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
318
 
319
  # measure length of each element in perturbation_batch
 
328
 
329
  # determine if need to pad or truncate batch
330
  minibatch_length_set = set(perturbation_minibatch["length"])
331
+ minibatch_lengths = perturbation_minibatch["length"]
332
  if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
333
  needs_pad_or_trunc = True
334
  else:
 
373
  # truncate to the (model input size - # tokens to overexpress) to ensure comparability
374
  # since max input size of perturb batch will be reduced by # tokens to overexpress
375
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
376
+ original_minibatch_lengths = original_minibatch["length"]
377
  original_minibatch_length_set = set(original_minibatch["length"])
378
  if perturb_type == "overexpress":
379
  new_max_len = model_input_size - len(tokens_to_perturb)
 
399
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
400
  else:
401
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
402
+
403
+ # remove perturbed index before calculating the cos sims
404
+ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
405
+ # indices_to_remove is list of indices to remove
406
+ gene_dim -= 1 # removing a dim in calling the function
407
+ indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
408
+ num_dims = emb.dim()
409
+ emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
410
+ sliced_emb = emb[emb_slice]
411
+ return sliced_emb
412
+
413
+ # this could probably be optimized
414
+ gene_dim = 1
415
+
416
+ # current there's the case if a gene is not expressed and is being overexpressed,
417
+ # the dimensions will be thrown off --> not removing indices to get around that issue
418
+ # not sure what's the best way to handle it
419
+ if perturb_type != "overexpress":
420
+ original_minibatch_emb = torch.stack([
421
+ remove_indices_from_emb(original_minibatch_emb[i, :, :], idx, gene_dim) for
422
+ i, idx in enumerate(indices_to_perturb)
423
+ ])
424
+
425
+ # do the averaging here
426
+
427
+
428
  # cosine similarity between original emb and batch items
429
  if cell_states_to_model is None:
430
  if perturb_group == False:
 
445
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
446
  minibatch_emb,
447
  state_embs_dict[state],
448
+ perturb_group,
449
+ torch.tensor(original_minibatch_lengths, device="cuda"),
450
+ torch.tensor(minibatch_lengths, device="cuda"))
451
  del outputs
452
  del minibatch_emb
453
  if cell_states_to_model is None:
 
462
  return cos_sims_vs_alt_dict
463
 
464
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
465
+ def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group, original_minibatch_lengths = None, minibatch_lengths = None,):
466
  cos = torch.nn.CosineSimilarity(dim=2)
467
+ if not perturb_group:
468
+ original_emb = torch.mean(original_emb,dim=0,keepdim=True)
469
  original_emb = original_emb[None, :]
470
+ origin_v_end = torch.squeeze(cos(original_emb, alt_emb))
471
+ else:
472
+ if original_emb.size() != minibatch_emb.size():
473
+ logger.error(
474
+ f"Embeddings are not the same dimensions. " \
475
+ f"original_emb is {original_emb.size()}. " \
476
+ f"minibatch_emb is {minibatch_emb.size()}. "
477
+ )
478
+ raise
479
+ from .emb_extractor import mean_nonpadding_embs
480
+
481
+ if original_minibatch_lengths is not None:
482
+ original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
483
+ # not sure if the else is necessary, but keeping it here in case
484
+ else:
485
+ original_emb = torch.mean(original_emb,dim=1,keepdim=True)
486
+
487
+ alt_emb = torch.unsqueeze(alt_emb, 1)
488
+ origin_v_end = cos(original_emb, alt_emb)
489
+ origin_v_end = torch.squeeze(origin_v_end)
490
+
491
+ if minibatch_lengths is not None:
492
+ perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
493
+ else:
494
+ perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
495
+
496
+ perturb_v_end = cos(perturb_emb, alt_emb)
497
+ perturb_v_end = torch.squeeze(perturb_v_end)
498
+
499
  return [(perturb_v_end-origin_v_end).to("cpu")]
500
 
501
  def pad_list(input_ids, pad_token_id, max_len):
 
773
 
774
  if self.cell_states_to_model is not None:
775
  if len(self.cell_states_to_model.items()) == 1:
776
+ logger.warning(
777
+ "The single value dictionary for cell_states_to_model will be " \
778
+ "replaced with explicitly modeling start and end states. " \
779
+ "Please specify state_key, start_state, end_state, and alt_states " \
780
+ "in the cell_states_to_model dictionary for future use."
781
+ )
782
  for key,value in self.cell_states_to_model.items():
783
  if (len(value) == 3) and isinstance(value, tuple):
784
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
786
  all_values = value[0]+value[1]+value[2]
787
  if len(all_values) == len(set(all_values)):
788
  continue
789
+ # reformat to the new format
790
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
791
+ self.cell_states_to_model = {
792
+ "state_key": list(self.cell_states_to_model.keys())[0],
793
+ "start_state": state_values[0][0],
794
+ "goal_state": state_values[1][0],
795
+ "alt_states": state_values[2:][0]
796
+ }
797
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
798
+ if self.cell_states_to_model["start_state"] is None or self.cell_states_to_model["goal_state"] is None:
799
+ logger.error(
800
+ "Please specify 'start_state' and 'goal_state' in cell_states_to_model.")
801
+ raise
802
+
803
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
804
+ logger.error(
805
+ "All states must be unique.")
806
+ raise
807
+
808
+ if self.cell_states_to_model["alt_states"] is not None:
809
+ if type(self.cell_states_to_model["alt_states"]) is not list:
810
+ logger.error(
811
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
812
+ )
813
+ raise
814
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
815
+ logger.error(
816
+ "All states must be unique.")
817
+ raise
818
+
819
  else:
820
  logger.error(
821
+ "states_to_model must only have the following four keys: 'state_key', 'start_state', 'goal_state', 'alt_states'." \
822
+ "For example, cell_states_to_model={ \
823
+ 'state_key': 'disease', \
824
+ 'start_state': 'dcm', \
825
+ 'goal_state': 'nf'', \
826
+ 'alt_states': ['hcm', 'other1', 'other2'] \
827
+ }"
828
+ )
829
  raise
830
+
831
  if self.anchor_gene is not None:
832
  self.anchor_gene = None
833
  logger.warning(
 
877
  if self.cell_states_to_model is None:
878
  state_embs_dict = None
879
  else:
880
+ # make sure that all states are valid; save time on filtering
881
+ state_name = self.cell_states_to_model["state_key"]
882
+ for value in get_possible_states(self.cell_states_to_model):
883
+ if value not in filtered_input_data[state_name]:
884
+ logger.error(
885
+ f"{value} is not a valid value in {state_name}.")
886
+ raise
887
  # get dictionary of average cell state embeddings for comparison
888
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
889
  state_embs_dict = get_cell_state_avg_embs(model,
 
894
  self.forward_batch_size,
895
  self.nproc)
896
  # filter for start state cells
897
+ start_state = self.cell_states_to_model["start_state"]
898
  def filter_for_origin(example):
899
+ return example[state_name] in [start_state]
900
 
901
  filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
902
 
 
992
  # or (perturbed_genes, "cell_emb") for avg cell emb change
993
  cos_sims_data = cos_sims_data.to("cuda")
994
  max_padded_len = cos_sims_data.shape[1]
 
995
  for j in range(cos_sims_data.shape[0]):
996
  # remove padding before mean pooling cell embedding
997
  original_length = original_lengths[j]
 
1013
  # update cos sims dict
1014
  # key is tuple of (perturbed_genes, "cell_emb")
1015
  # value is list of tuples of cos sims for cell_states_to_model
1016
+ origin_state_key = self.cell_states_to_model["start_state"]
1017
  cos_sims_origin = cos_sims_data[origin_state_key]
1018
  for j in range(cos_sims_origin.shape[0]):
 
 
 
 
 
 
1019
  data_list = []
1020
  for data in list(cos_sims_data.values()):
1021
  data_item = data.to("cuda")
1022
+ data_list += [data_item]
 
 
1023
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1024
 
1025
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
 
1092
  # update cos sims dict
1093
  # key is tuple of (perturbed_gene, "cell_emb")
1094
  # value is list of tuples of cos sims for cell_states_to_model
1095
+ origin_state_key = self.cell_states_to_model["start_state"]
1096
  cos_sims_origin = cos_sims_data[origin_state_key]
1097
 
1098
  for j in range(cos_sims_origin.shape[0]):
 
1214
  # save remainder cells
1215
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1216
  pickle.dump(cos_sims_dict, fp)