In [None]:
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor

### In silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards the non-failing (nf) state

In [None]:
# Step 1: obtain the embedding positions of the start, goal, and alt states.
# This is a separate step from perturb_data so that the calculations are not
# repeated when perturb_data is parallelized.

# Only cells whose "cell_type" is one of these values are kept
# (passed as filter_data to both the extractor and the perturber).
filter_data_dict = {
    "cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"],
}

# States are read from the "disease" key: start at dcm, aim for nf,
# with hcm listed as the alternate state.
cell_states_to_model = {
    "state_key": "disease",
    "start_state": "dcm",
    "goal_state": "nf",
    "alt_states": ["hcm"],
}

embex = EmbExtractor(
    model_type="CellClassifier",
    num_classes=3,
    filter_data=filter_data_dict,
    max_ncells=1000,
    emb_layer=0,
    summary_stat="exact_mean",
    forward_batch_size=256,
    nproc=16,
)

# The resulting dict is handed to InSilicoPerturber as state_embs_dict.
state_embs_dict = embex.get_state_embs(
    cell_states_to_model,
    "path/to/model",
    "path/to/input_data",
    "path/to/output_directory",
    "output_prefix",
)

In [None]:
# Step 2: configure the perturber in "delete" mode over all genes, reusing the
# cell-type filter, the state definitions, and the precomputed state embeddings.
isp = InSilicoPerturber(
    perturb_type="delete",
    perturb_rank_shift=None,
    genes_to_perturb="all",
    combos=0,
    anchor_gene=None,
    model_type="CellClassifier",
    num_classes=3,
    emb_mode="cell",
    cell_emb_style="mean_pool",
    filter_data=filter_data_dict,
    cell_states_to_model=cell_states_to_model,
    state_embs_dict=state_embs_dict,
    max_ncells=2000,
    emb_layer=0,
    forward_batch_size=400,
    nproc=16,
)

In [None]:
# Step 3: run the in silico perturbation; this writes intermediate files
# that the stats step reads afterwards.
isp.perturb_data(
    "path/to/model",
    "path/to/input_data",
    "path/to/output_directory",
    "output_prefix",
)

In [None]:
# Step 4: configure the statistics pass in "goal_state_shift" mode, using the
# same gene selection (all genes, no combos, no anchor gene) and the same
# state definitions as the perturbation run.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",
    genes_perturbed="all",
    combos=0,
    anchor_gene=None,
    cell_states_to_model=cell_states_to_model,
)

In [None]:
# Step 5: extract data from the intermediate files and process the stats
# into the final .csv output.
# NOTE(review): the first argument is presumably the directory holding the
# intermediate files written by perturb_data (its output directory), not the
# original dataset — confirm before replacing the placeholder paths.
ispstats.get_stats(
    "path/to/input_data",
    None,
    "path/to/output_directory",
    "output_prefix",
)