Replicating results for Figure 6

#246
by MinieRosie - opened

Hello, I am excited to apply Geneformer to my own data. First, I am trying to replicate the results from the manuscript (in particular Supplementary Table 12, sheet "DCM_del_tx") to make sure I am using Geneformer correctly.

Here is the code I am running. I copied many of the parameters from your example "in_silico_perturbation.ipynb". The inputs are the 220224 fine-tuned model you provide and the tokenized human_dcm_hcm_nf.dataset file, which is also provided.


from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats

isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="CellClassifier",
                        num_classes=3,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        cell_states_to_model={'state_key':'disease',
                                              'start_state':'dcm',
                                              'goal_state':'nf',
                                              'alt_states':['hcm']},
                        max_ncells=10000,
                        emb_layer=0,
                        forward_batch_size=50,
                        nproc=32)
                    
isp.perturb_data("/yanketn/TOOLS/Geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/",
                 "/yanketn/ANALYSIS/Geneformer/human_dcm_hcm_nf.dataset/",
                 "/yanketn/ANALYSIS/Geneformer/",
                 "test_geneformer_fig6_geneformertoken_10k")

ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model={"disease":(["dcm"],["nf"],["hcm"])})


ispstats.get_stats("/yanketn/ANALYSIS/Geneformer/",
                   None,
                   "/yanketn/ANALYSIS/Geneformer/",
                   "test_fig6_stats_geneformertoken_10k")

The first difference I saw was the number of unique genes detected: my run detects 17419 genes, whereas Supplementary Table 12 has 17031 unique genes. I get an r^2 of 0.68, which seems reasonable to me.

[Image: Screen Shot 2023-09-16 at 6.20.49 PM.png]

I next compared the shift and FDR columns between Supplementary Table 12 and my own results, and this is where I see big differences.

[Image: Screen Shot 2023-09-16 at 6.22.25 PM.png]
[Image: Screen Shot 2023-09-16 at 6.23.36 PM.png]
[Image: Screen Shot 2023-09-16 at 6.24.38 PM.png]
[Image: Screen Shot 2023-09-16 at 6.25.10 PM.png]

I also looked at the raw data to compare some of the genes mentioned in the manuscript:

From Table 12 DCM_del_tx:
[Image: Screen Shot 2023-09-16 at 6.25.53 PM.png]

From my run:

[Image: Screen Shot 2023-09-16 at 6.40.04 PM.png]

I also compared these genes:

[Image: Screen Shot 2023-09-16 at 6.41.09 PM.png]

Lastly, I intersected the "significant" candidate genes. Supplementary Table 12 marks about 250 genes where the FDR is < 0.05, the NF shift is positive, and the HCM shift is negative.

Only about 135 genes meet these criteria in both the manuscript table and my own Geneformer run. However, when I run a Gene Ontology BP analysis, I get many of the same ontologies that you report in your Supplementary Table 15.
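The candidate-gene comparison described above can be sketched as follows. This is a minimal illustration on toy values, not output from the manuscript or from any real run. The field names (`fdr`, `shift_to_nf`, `shift_to_hcm`) and the helper `candidate_genes` are hypothetical and would need to be mapped to the actual column headers of the table and of the stats output.

```python
# Hypothetical field names; map them to the real CSV headers before use.
def candidate_genes(rows, fdr_cutoff=0.05):
    """Return genes with FDR below the cutoff, a positive shift toward NF,
    and a negative shift toward HCM."""
    return {
        row["gene"]
        for row in rows
        if row["fdr"] < fdr_cutoff
        and row["shift_to_nf"] > 0
        and row["shift_to_hcm"] < 0
    }


# Toy rows standing in for the manuscript table and for a local run.
table_rows = [
    {"gene": "GENE_A", "fdr": 0.01, "shift_to_nf": 0.004, "shift_to_hcm": -0.002},
    {"gene": "GENE_B", "fdr": 0.03, "shift_to_nf": 0.002, "shift_to_hcm": -0.001},
    {"gene": "GENE_C", "fdr": 0.20, "shift_to_nf": 0.003, "shift_to_hcm": -0.001},
]
my_rows = [
    {"gene": "GENE_A", "fdr": 0.02, "shift_to_nf": 0.003, "shift_to_hcm": -0.001},
    {"gene": "GENE_B", "fdr": 0.04, "shift_to_nf": 0.001, "shift_to_hcm": 0.002},
    {"gene": "GENE_D", "fdr": 0.01, "shift_to_nf": 0.005, "shift_to_hcm": -0.003},
]

table_hits = candidate_genes(table_rows)
my_hits = candidate_genes(my_rows)
shared = table_hits & my_hits  # genes meeting the criteria in both sources

print(sorted(shared))
```

Applying the same filter function to both sources keeps the criteria identical on each side, so any remaining difference in the gene sets comes from the underlying values rather than from the filtering.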

The only explanation I can think of for such different outputs is the number of cells tested. Can you give any insight into why my results are so different? I just want to make sure I have Geneformer running correctly before I run it on my own data. Thank you very much for your time.


Thank you for your interest in Geneformer! If you'd like to exactly replicate the results from the manuscript, you can follow the steps below. Please let me know if your results still don't match after these steps. Please note that these steps are not necessarily recommended for your own analysis. For example, shuffling the data is unnecessary, since the forward pass is deterministic and you are not training the model during the in silico perturbation. We shuffled the cells for other reasons, but to exactly replicate the results you would need to follow the same shuffling, so that the final subsetting step (which we performed due to resource limitations) leaves you with the same set of cells.

Please also note that these steps refer to the analysis in Table 12 tab "DCM_del_tx", which identifies genes that shift the dilated cardiomyopathy cells towards non-failing. In your post you also mentioned the genes ADCY5 and SRPK3. I wanted to clarify that those two genes were candidates from the analysis identifying genes that shift the hypertrophic cardiomyopathy cells towards non-failing, so the start state would be different in that analysis.

from datasets import load_from_disk
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats

dataset = load_from_disk("/path/to/human_dcm_hcm_nf.dataset")

# subset to cardiomyocytes 
# (of note, in your post you did not seem to subset to cardiomyocytes - this would cause a large difference in the results as the embeddings are context-aware)
def if_cell_type(example):
    return example["cell_type"].startswith("Cardiomyocyte")

dataset = dataset.filter(if_cell_type, num_proc=16)

# subset the cells to the non-held-out individuals for which analysis was performed in the manuscript
indiv_subset = ["1610", "1371", "1515", "1678", "1630", "1726", "1504", "1430", "1558", "1447", "1472", "1622", "1304", "1707", "1582", "1617", "1479", "1549", "1425", "1561", "1722", "1600", "1631", "1358", "1290", "1300", "1508", "1735", "1462"]

def if_indiv(example):
    return example["individual"] in indiv_subset

dataset = dataset.filter(if_indiv, num_proc=16).shuffle(seed=42)

dataset = dataset.shuffle(seed=42)  # do not omit this step, while unnecessary, if you'd like to use the same set of cells as the original analysis

# generate embeddings for each individual using the emb_extractor, then average the embeddings by disease state to derive the 3 goal positions
# of note, the goal state embeddings used for the analysis in the manuscript did not include an attention mask for padding
# the embeddings with or without attention mask are correlated with coefficient >0.999 so this should have minimal impact on results
# if you prefer, you can modify the emb_extractor to not include an attention mask to exactly derive the same goal positions
# however, attention masks are recommended in general for your own analysis to fully minimize impact of padding

# you will have to modify the code of in_silico_perturber.py on line 967 to take your average embeddings generated as above as the state_embs_dict
# rather than using the ones that would be generated just from a subset of the data with get_cell_state_avg_embs

# after deriving the goal positions, subset to the start state of "dcm" to replicate Table 12 tab "DCM_del_tx"
def if_disease(example):
    return example["disease"] == "dcm"

dataset = dataset.filter(if_disease, num_proc=16)

# subset the cells as was needed for resource limitations of max runtime for original analysis to replicate Table 12 tab "DCM_del_tx"
dataset = dataset.select([i for i in range(9701)])

# save the subsetted dataset to use for in silico deletion analysis
dataset.save_to_disk("output_path/to/subsetted.dataset")

# in silico perturbation in deletion mode to determine genes whose 
# deletion in the dilated cardiomyopathy (dcm) state significantly shifts
# the embedding towards non-failing (nf) state
isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="CellClassifier",
                        num_classes=3,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
                        cell_states_to_model={'state_key': 'disease', 
                                              'start_state': 'dcm', 
                                              'goal_state': 'nf', 
                                              'alt_states': ['hcm']},
                        max_ncells=None,
                        emb_layer=0,
                        forward_batch_size=400,
                        nproc=16)

# outputs intermediate files from in silico perturbation
isp.perturb_data("path/to/model",
                 "path/to/subsetted.dataset",   # use saved subsetted .dataset
                 "path/to/output_directory",
                 "output_prefix")

ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model={'state_key': 'disease',
                                                        'start_state': 'dcm',
                                                        'goal_state': 'nf',
                                                        'alt_states': ['hcm']})

# extracts data from intermediate files and processes stats to output in final .csv
ispstats.get_stats("path/to/input_data",
                   None,
                   "path/to/output_directory",
                   "output_prefix")
ctheodoris changed discussion status to closed
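A minimal sketch of why the shuffle steps in the code above matter for replication, even though they do not affect the model's forward pass. It uses Python's `random.Random` as a stand-in for the `datasets` shuffle, which is an assumption made purely for illustration: the two do not produce the same permutations. The helper `shuffled` is hypothetical.

```python
import random


def shuffled(items, seed):
    """Return a shuffled copy of items using a freshly seeded generator."""
    out = list(items)
    random.Random(seed).shuffle(out)
    return out


cells = list(range(20))  # stand-in for cell indices

once = shuffled(cells, seed=42)   # one seeded shuffle
twice = shuffled(once, seed=42)   # the same seeded shuffle applied again

# The same seed and the same sequence of steps always give the same order.
repeat = shuffled(shuffled(cells, seed=42), seed=42)
print(repeat == twice)

# Shuffling once and shuffling twice give different orders, so a later
# "take the first N cells" step would generally select different cells.
print(once != twice)
```

The point is only that a fixed-size selection taken after the shuffles depends on reproducing every ordering step, including one that looks redundant.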

Thank you for providing the code to replicate your results! I am trying to replicate the same table output (Table 12 tab "DCM_del_tx") but am unable to replicate the results exactly. Following the instructions above, I ran the following code:

from datasets import load_from_disk
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor

dataset = load_from_disk("human_dcm_hcm_nf.dataset")

def if_cell_type(example):
    return example["cell_type"].startswith("Cardiomyocyte")

dataset = dataset.filter(if_cell_type, num_proc=16)

indiv_subset = ["1610", "1371", "1515", "1678", "1630", "1726", "1504", "1430", "1558", "1447", "1472", "1622", "1304", "1707", "1582", "1617", "1479", "1549", "1425", "1561", "1722", "1600", "1631", "1358", "1290", "1300", "1508", "1735", "1462"]

def if_indiv(example):
    return example["individual"] in indiv_subset

dataset = dataset.filter(if_indiv, num_proc=16).shuffle(seed=42)

dataset = dataset.shuffle(seed=42)  # do not omit this step, while unnecessary, if you'd like to use the same set of cells as the original analysis

dataset.save_to_disk("subsetted.dataset")

cell_states_to_model={"state_key": "disease", 
                      "start_state": "dcm", 
                      "goal_state": "nf", 
                      "alt_states": ["hcm"]}

filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}

embex = EmbExtractor(model_type="CellClassifier",
                     num_classes=3,
                     filter_data=filter_data_dict,
                     max_ncells=1000,
                     emb_layer=0,
                     summary_stat="exact_mean",
                     forward_batch_size=256,
                     nproc=16)

state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       "Geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
                                       "subsetted.dataset",
                                       "Embs",
                                       "embs")

def if_disease(example):
    return example["disease"] == "dcm"

dataset = dataset.filter(if_disease, num_proc=16)

dataset = dataset.select([i for i in range(9701)])

dataset.save_to_disk("dementia_subsetted.dataset")

isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="CellClassifier",
                        num_classes=3,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        filter_data=filter_data_dict,
                        cell_states_to_model=cell_states_to_model,
                        state_embs_dict=state_embs_dict,
                        max_ncells=2000,
                        emb_layer=0,
                        forward_batch_size=150,
                        nproc=12)

isp.perturb_data("Geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
                 "dementia_subsetted.dataset",
                 "PerturbationOutput",
                 "output")

ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model=cell_states_to_model)

ispstats.get_stats("PerturbationOutput",
                   None,
                   "PerturbationOutput",
                   "outputtedGenes")

I noticed that the outputs of this code include a 'Sig' column that denotes significance according to FDR-corrected p-values < 0.05, which differs from the Shift_Towards_NF_and_Away_from_DCM_and_HCM column in Table 12. I therefore added a column called Sig_Shift to my results that replicates the Shift_Towards_NF_and_Away_from_DCM_and_HCM column so that I could compare the two sets of results. I set Sig_Shift to 1 if there was a positive shift_to_goal_end and a negative shift_to_alt_end_hcm. I then ran a Fisher's test to examine the overlap, treating genes with Sig_Shift = 1 as my significant genes and genes with Sig_Shift = 0 as my nonsignificant genes, and treating genes with Shift_Towards_NF_and_Away_from_DCM_and_HCM = 1 or 0 as the original significant or nonsignificant genes, respectively. I am attaching the contingency table and Fisher's test results below. Not all of the genes overlapped, and I am wondering why that may be the case. Is there an error in my code above or in my follow-up integration analysis that could have led to this difference? Thank you in advance for all your help!
[Image: Fisher_Compare_Mine_Theodoris_Geneformer.png]
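The overlap test described in this post can be sketched as follows, on toy gene sets rather than the real results. This is a minimal pure-Python illustration: the helpers `contingency` and `fisher_exact_two_sided` are hypothetical names, and the two-sided p-value is computed directly from the hypergeometric distribution by summing the probabilities of all tables no more likely than the observed one.

```python
from math import comb


def contingency(mine, original, all_genes):
    """Count genes by membership in the two significant-gene sets."""
    a = b = c = d = 0
    for gene in all_genes:
        in_mine, in_original = gene in mine, gene in original
        if in_mine and in_original:
            a += 1  # significant in both
        elif in_mine:
            b += 1  # significant only in my run
        elif in_original:
            c += 1  # significant only in the original table
        else:
            d += 1  # significant in neither
    return a, b, c, d


def fisher_exact_two_sided(a, b, c, d):
    """Two-sided Fisher's exact test p-value for the table [[a, b], [c, d]]."""
    row1, col1, n = a + b, a + c, a + b + c + d

    def prob(x):
        # Hypergeometric probability of x in the top-left cell.
        return comb(col1, x) * comb(n - col1, row1 - x) / comb(n, row1)

    lowest = max(0, row1 - (n - col1))
    highest = min(row1, col1)
    p_observed = prob(a)
    return sum(
        prob(x)
        for x in range(lowest, highest + 1)
        if prob(x) <= p_observed * (1 + 1e-9)
    )


# Toy gene universe and significant sets (illustrative only).
all_genes = [f"G{i}" for i in range(1, 9)]
original_sig = {"G1", "G2", "G3", "G4"}
my_sig = {"G1", "G2", "G3", "G5"}

table = contingency(my_sig, original_sig, all_genes)
p_value = fisher_exact_two_sided(*table)
print(table, round(p_value, 4))
```

Both sets should be restricted to the same gene universe before counting, since genes present in only one of the two outputs would otherwise inflate the "neither" or "only one" cells.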

Thank you for your question! There are a few places I see that are different from the code provided that may be affecting your results:

  • In generating the state embs dict, we used all the cells after subsetting to the given individuals/cell types, not a max of 1000 as you have set.
  • Additionally, as discussed above, we first generated embeddings for each individual using the emb_extractor, then averaged the embeddings by disease state to derive the 3 goal positions. This ensured that the final embedding position was not weighted by the number of cells from each individual (i.e. that an individual where more cells were sequenced did not have a greater weight on the embedding position of the given state).
  • Finally, we used all of the subsetted 9701 cells for the in silico perturbation, so the max_ncells should be None for the InSilicoPerturber, as indicated in the code we provide above.

Thank you so much for the help! I can fix the first and third issues, but I am slightly confused about how to do the second. I thought that I was following the instructions by using the embeddings extractor on the subsetted individuals rather than the entire dataset, but is my issue that I need to again average the embeddings by disease state? I'm not exactly sure how to average the embeddings by disease state and am wondering if you can provide some additional guidance.

Sure. The way you have set it up, cells from the subsetted individuals are provided to the EmbExtractor, which then proceeds with get_state_embs to get the cell embeddings for all of those cells, separate them into groups based on disease state, and then average those groups. Therefore, if there is a healthy individual A with N cardiomyocytes sequenced and another healthy individual B with 100*N cardiomyocytes sequenced, the average embedding position for healthy cardiomyocytes will be 100x more influenced by individual B than individual A. The way that we did the analysis, we extracted embeddings from all the cardiomyocytes, then averaged them within each individual so that there was a single mean embedding position for each individual's cardiomyocytes, and then averaged these individual embedding positions for each disease state. That way, individual A and individual B contribute equally to the final embedding position of the healthy cardiomyocyte state. To do this, you can use extract_embs to output the embeddings of all the cells, and then average them as stated, first by individual and then by state, and arrange them in a dictionary similar to what the output of get_state_embs would be. This approach isn't necessarily needed for all studies (e.g. it is less needed if the number of cells per individual is already adequately balanced), but this is what we did for our analysis.
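The two-level averaging described above can be sketched as follows. This is a minimal illustration on toy two-dimensional vectors, not the original analysis code: the helpers `mean_vector` and `state_means_balanced` are hypothetical, and real embeddings would come from extract_embs. The exact structure and value types that the perturber expects for the state embeddings dictionary are not shown here and should be checked against what get_state_embs returns.

```python
from collections import defaultdict
from statistics import fmean


def mean_vector(vectors):
    """Element-wise mean of equal-length vectors."""
    return [fmean(column) for column in zip(*vectors)]


def state_means_balanced(cells):
    """Average embeddings within each individual, then across individuals
    within each disease state.

    cells: iterable of (individual, disease, embedding) tuples.
    """
    per_individual = defaultdict(list)
    for individual, disease, emb in cells:
        per_individual[(disease, individual)].append(emb)

    per_state = defaultdict(list)
    for (disease, _individual), embs in per_individual.items():
        per_state[disease].append(mean_vector(embs))

    return {disease: mean_vector(means) for disease, means in per_state.items()}


# Toy data: individual B has three times as many cells as individual A.
cells = [
    ("A", "nf", [0.0, 0.0]),
    ("B", "nf", [1.0, 2.0]),
    ("B", "nf", [1.0, 2.0]),
    ("B", "nf", [1.0, 2.0]),
    ("C", "dcm", [2.0, 2.0]),
    ("C", "dcm", [4.0, 0.0]),
]

balanced = state_means_balanced(cells)
pooled_nf = mean_vector([emb for _, disease, emb in cells if disease == "nf"])

print(balanced["nf"])  # each individual weighted equally
print(pooled_nf)       # each cell weighted equally, so B dominates
```

In the toy data the pooled mean is pulled toward individual B, while the balanced mean sits halfway between A and B, which is the effect the two-level averaging is meant to achieve.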

Thank you for all your guidance! I have resolved these issues and am currently running the in silico perturbation. However, due to out-of-memory errors, I can only use a batch size of 175, and I am worried that this may affect the outputs when I try to replicate your results. I am using an A100 GPU with 80 GB of memory. When I run the in silico perturbation with all the cells as you specified, the estimated runtime is 58 hours. Is it normal to run into these memory issues with this memory capacity, and is this runtime expected with my GPU? In addition, which GPU and how much memory did you use for your analysis, and did you use any other methods to optimize it so that it could run with a larger batch size? Thanks for all your time and attention.

@hkamineni My run took about 72 hours to go through all 9700 cells. I was able to replicate the results from the paper almost exactly using the instructions given above.

@hkamineni Thank you for following up. The batch size will not affect the results since in this case you are using the model only for inference and not fine-tuning the model. The memory can be affected both by batch size and the total accumulation of data across 1000 cells, which is the current setting for clearing the memory (you can change this if needed). However, the batch size you have selected seems reasonable, as does the runtime. Keep in mind you are deleting every single gene in every single cell and measuring the cosine shift compared to the original cell and goal and alternate position, so for every cell, that's usually about 2048 simulated cells you are testing as well, which is a large number of computations.
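For a sense of scale, the numbers quoted in this thread can be combined in a back-of-the-envelope estimate. This is a rough sketch only: it assumes every cell contributes the full 2048 simulated cells (the reply says "usually about 2048", so this is an upper-end figure) and it ignores all overhead outside the forward passes.

```python
n_cells = 9701          # cells in the subsetted DCM dataset (from the thread)
genes_per_cell = 2048   # assumed upper-end count of deletions per cell
runtime_hours = 58      # runtime estimate reported in the thread

# One simulated cell per deleted gene per original cell.
simulated_cells = n_cells * genes_per_cell

# Implied average throughput if the run takes the estimated time.
per_second = simulated_cells / (runtime_hours * 3600)

print(f"{simulated_cells:,} simulated cells")
print(f"{per_second:.0f} simulated cells per second")
```

Under these assumptions the run evaluates close to 20 million simulated cells, which is consistent with a runtime measured in days rather than hours.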

@MinieRosie Thank you also for following up with that information and glad to hear you were able to replicate the results with that process.

@MinieRosie @ctheodoris Thank you both so much for the information!
