torch.cuda.OutOfMemoryError: CUDA out of memory in in_silico_perturbation

#276
by kchen360 - opened

Hi,

I tried to run the in silico perturbation analysis as in the tutorial:

isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="CellClassifier",
                        num_classes=3,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
                        cell_states_to_model={'state_key': 'disease', 
                                              'start_state': 'dcm', 
                                              'goal_state': 'nf', 
                                              'alt_states': ['hcm']},
                        max_ncells=None,
                        emb_layer=0,
                        forward_batch_size=5,
                        nproc=16)

isp.perturb_data("./fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
                 "./data/human_dcm_hcm_nf.dataset",
                 "./perturb_out/",
                 "emb_test")

ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  #cell_states_to_model={"disease":(["dcm"],["nf"],["hcm"])})
                                  cell_states_to_model={'state_key': 'disease', 
                                                        'start_state': 'dcm', 
                                                        'goal_state': 'nf', 
                                                        'alt_states': ['hcm']})

ispstats.get_stats("./perturb_out/",
                   None,
                   "./perturb_stats_out/",
                   "emb_test")

And error occurs after a while:

Filter (num_proc=16): 100%|██████████████████████████████| 579159/579159 [00:15<00:00, 38123.09 examples/s]
Filter (num_proc=16): 100%|██████████████████████████████| 579159/579159 [00:15<00:00, 36506.16 examples/s]
Traceback (most recent call last):
  File "/home/Geneformer/in_silico_perturbation.py", line 66, in <module>
    isp.perturb_data("./fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
  File "/home/geneformer/in_silico_perturber.py", line 969, in perturb_data
    state_embs_dict = get_cell_state_avg_embs(model,
  File "/home/Geneformer/geneformer/in_silico_perturber.py", line 290, in get_cell_state_avg_embs
    outputs = model(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1564, in forward
    outputs = self.bert(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1022, in forward
    encoder_outputs = self.encoder(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 612, in forward
    layer_outputs = layer_module(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 497, in forward
    self_attention_outputs = self.attention(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 427, in forward
    self_outputs = self.self(
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/biomedical_py310/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 365, in forward
    context_layer = torch.matmul(attention_probs, value_layer)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 10.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 7.25 MiB is free. Including non-PyTorch memory, this process has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 12.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I didn't change any source code; this is the example code I ran. Our GPU has 80 GiB of VRAM and the error still occurs. Have you ever encountered this error, and do you have any tips for solving it? Thank you very much.

Following up on this thread: I noticed that a memory leak might occur in the function get_cell_state_avg_embs() at these lines:

            with torch.no_grad():
                outputs = model(
                    input_ids = input_data_minibatch.to("cuda"),
                    attention_mask = attention_mask
                )
            state_embs_i = outputs.hidden_states[layer_to_quant]
            state_embs_list += [state_embs_i]
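The excerpt above appends each full `outputs.hidden_states[layer_to_quant]` tensor (batch × tokens × hidden) to `state_embs_list`, which stays on the GPU, so memory grows with every minibatch. One common way to keep memory flat, sketched here as an assumption and not as Geneformer's actual fix, is to mean-pool each batch over its non-padding tokens and keep only a running sum. `mean_of_pooled_embs`, `embed_fn`, and `toy_embed` are hypothetical names; `embed_fn` stands in for the model forward pass that returns one layer's hidden states.

```python
import torch


def mean_of_pooled_embs(batches, embed_fn, device="cpu"):
    """Average mean-pooled cell embeddings while keeping only a running sum.

    batches:  iterable of (input_ids, attention_mask), each of shape (batch, seq_len).
    embed_fn: hypothetical stand-in for the model forward pass; returns hidden
              states of shape (batch, seq_len, hidden) for the chosen layer.
    """
    running_sum = None
    n_cells = 0
    with torch.no_grad():
        for input_ids, attention_mask in batches:
            input_ids = input_ids.to(device)
            mask = attention_mask.to(device)
            hidden = embed_fn(input_ids, mask)
            # Mean-pool over non-padding tokens -> (batch, hidden)
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            cell_embs = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
            # Keep only a (hidden,)-sized sum; this batch's per-token tensors are
            # no longer referenced once these names are rebound next iteration.
            batch_sum = cell_embs.sum(dim=0)
            running_sum = batch_sum if running_sum is None else running_sum + batch_sum
            n_cells += cell_embs.shape[0]
    return (running_sum / n_cells).cpu()


device = "cuda" if torch.cuda.is_available() else "cpu"


def toy_embed(ids, mask):
    # Toy model: each token's "embedding" is its id repeated across 4 dims.
    return ids.float().unsqueeze(-1).expand(-1, -1, 4)


batches = [
    (torch.tensor([[1, 2, 0], [3, 0, 0]]), torch.tensor([[1, 1, 0], [1, 0, 0]])),
    (torch.tensor([[5, 5, 5]]), torch.tensor([[1, 1, 1]])),
]
result = mean_of_pooled_embs(batches, toy_embed, device)
# Each entry should equal (1.5 + 3 + 5) / 3, i.e. about 3.1667
print(result)
```

If per-cell embeddings are needed later, calling `.cpu()` on `cell_embs` before storing it would be an alternative that still keeps GPU usage bounded to one batch.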

I noticed that approximately 0.1 GB is added to GPU memory in each iteration. With max_ncells=None, the loop takes ~1211 iterations to finish (60579 examples in total with a forward batch size of 50), so completing it would require about 120 GB of VRAM, which I don't think is expected. I understand that memory grows with each iteration because state_embs_list stays on the GPU, but could you provide a reasonable estimate of how much memory should be added per iteration in your cardiomyopathy knockout examples? Thanks.
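The ~0.1 GB per iteration reported above is consistent with keeping the full per-token hidden states rather than a pooled embedding. A back-of-the-envelope sketch, assuming (not stated in the thread) a hidden size of 256 for Geneformer 6L-30M, every minibatch padded to the 2,048-token input limit, and float32 values; `hidden_state_bytes` and `retained_gib` are hypothetical helpers. Under these assumptions the total depends only on the number of cells, not on the batch size, which would be consistent with a smaller forward_batch_size alone not avoiding the error.

```python
# Rough upper bound on GPU memory retained by keeping every minibatch's
# hidden states in a Python list that stays on the GPU.
# Assumptions (not from the thread): hidden size 256 (Geneformer 6L-30M),
# every minibatch padded to 2,048 tokens, float32 (4-byte) values.

GIB = 2**30


def hidden_state_bytes(n_cells, seq_len=2048, hidden_size=256, bytes_per_value=4):
    """Bytes held by hidden states of shape (n_cells, seq_len, hidden_size)."""
    return n_cells * seq_len * hidden_size * bytes_per_value


def retained_gib(n_cells, batch_size, **kwargs):
    """Return (GiB per batch, GiB for the whole loop) if every batch is kept."""
    per_batch = hidden_state_bytes(batch_size, **kwargs) / GIB
    # Summing over all batches (including a smaller final one) equals one
    # tensor covering every cell, so the total does not depend on batch_size.
    total = hidden_state_bytes(n_cells, **kwargs) / GIB
    return per_batch, total


per_batch, total = retained_gib(n_cells=60579, batch_size=50)
print(f"per batch: {per_batch:.3f} GiB")  # per batch: 0.098 GiB
print(f"whole loop: {total:.1f} GiB")     # whole loop: 118.3 GiB
```

In practice, batches padded to shorter lengths would retain less, so these figures are an upper bound.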

Thank you for your interest in Geneformer and for your patience! We pushed an update that should resolve this issue. If you continue to face errors after pulling the updated code, please let us know by either reopening this discussion if it's the same error or opening a new discussion if it's a new error. Please note that state_embs_dict is now derived beforehand with the EmbExtractor and then passed to the in silico perturber. Computing it outside the perturber means it doesn't need to be repeated when parallelizing the in silico perturbation (see the example in the examples directory). Thank you!

ctheodoris changed discussion status to closed
