File size: 4,917 Bytes
e562c0c
 
025e1b8
e562c0c
025e1b8
 
 
 
 
 
 
e562c0c
025e1b8
 
 
 
 
 
 
 
e562c0c
 
 
025e1b8
 
e562c0c
025e1b8
 
e562c0c
025e1b8
e562c0c
 
 
 
 
 
 
 
025e1b8
e562c0c
 
025e1b8
 
e562c0c
025e1b8
 
e562c0c
025e1b8
e562c0c
 
025e1b8
e562c0c
 
 
025e1b8
 
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
025e1b8
 
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
025e1b8
e562c0c
025e1b8
e562c0c
 
025e1b8
e562c0c
025e1b8
 
 
 
 
e562c0c
025e1b8
 
 
 
 
 
 
 
 
 
 
e562c0c
025e1b8
e562c0c
 
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
 
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
e562c0c
 
025e1b8
 
e562c0c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
from pathlib import Path

import torch

# The default-argument expressions below (``Path(...)``, ``os.cpu_count()``,
# ``torch.cuda.device_count()``) are evaluated when each ``def`` statement
# runs, i.e. at import time, so these names must be bound at module level.


# Cell classifier
def finetune_cells(
    token_set=Path('geneformer/token_dictionary.pkl'),
    median_set=Path('geneformer/gene_median_dictionary.pkl'),
    pretrained_model=".",
    dataset='Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
    dataset_split=None,
    filter_cells=.005,
    epochs=1,
    cpu_cores=os.cpu_count(),
    geneformer_batch_size=12,
    optimizer='adamw',
    max_lr=5e-5,
    num_gpus=torch.cuda.device_count(),
    max_input_size=2 ** 11,
    lr_schedule_fn="linear",
    warmup_steps=500,
    freeze_layers=0,
    emb_extract=False,
    max_cells=1000,
    emb_layer=0,
    emb_filter=None,
    emb_dir='embeddings',
    overwrite=True,
    label="cell_type",
    data_filter=None,
    forward_batch=200,
    model_location=None,
    skip_training=False,
    sample_data=1,
    inference=False,
    optimize_hyperparameters=False,
    output_dir=None,
):
    """Fine-tune a pretrained Geneformer model to classify cells.

    NOTE(review): in this file the function body consists only of this
    docstring, so a call currently performs no work and returns None. The
    descriptions below record the intended contract -- confirm where the
    implementation lives.

    Defaults that call a function (``cpu_cores``, ``num_gpus``) are evaluated
    once, when the ``def`` statement runs, not on every call.

    Primary Parameters
    ------------------
    dataset: path
        Path to the fine-tuning/testing dataset used for training.

    model_location: path
        Path to an existing model to use for inference and embedding
        extraction. Defaults to None.

    pretrained_model: path
        Path to the pretrained Geneformer 30M model used before fine-tuning.
        Defaults to ".".

    inference: bool
        Whether to perform inference (intended to make the function return
        the list of similarities). Defaults to False.

    skip_training: bool
        Whether to skip training the model. Defaults to False.

    emb_extract: bool
        Whether to extract embeddings and calculate similarities. Defaults to False.

    optimize_hyperparameters: bool
        Whether to optimize the model hyperparameters. Defaults to False.

    label: str
        Name of the feature in the formatted dataset that holds the true class labels. Defaults to "cell_type".

    Customization Parameters
    ------------------------
    dataset_split: str
        How the dataset should be partitioned (if at all), and which ID
        should be used for partitioning. Defaults to None.

    data_filter: list
        (For embeddings and inference) Runs the analysis on subsets of the
        dataset selected by the ID defined by ``dataset_split``.
        Defaults to None.

    emb_layer: int
        Layer from which embeddings are extracted and compared. Defaults to 0.

    emb_filter: ['cell1', 'cell2', ...]
        Narrows down the set of cells that embeddings will be extracted from.
        Defaults to None.

    max_cells: int
        Maximum number of cells to extract embeddings from. Defaults to 1000.

    freeze_layers: int
        Number of layers, counted from the first layer, kept frozen during
        fine-tuning. Defaults to 0. NOTE(review): earlier docs said "4 brings
        it up to the pretrained model"; the intended meaning is unclear --
        confirm.

    sample_data: float
        Proportion of the Hugging Face dataset to use. Defaults to 1.

    Other Parameters
    ----------------
    NOTE(review): these were previously undocumented; the descriptions are
    inferred from the parameter names and defaults -- confirm against the
    implementation.

    token_set: path
        Presumably the gene token dictionary pickle.
        Defaults to Path('geneformer/token_dictionary.pkl').

    median_set: path
        Presumably the gene median dictionary pickle.
        Defaults to Path('geneformer/gene_median_dictionary.pkl').

    filter_cells: float
        Meaning not documented (TODO). Defaults to .005.

    epochs: int
        Presumably the number of training epochs. Defaults to 1.

    cpu_cores: int or None
        Presumably the number of CPU worker processes. Defaults to
        ``os.cpu_count()``, which can be None if the count is undeterminable.

    geneformer_batch_size: int
        Presumably the training batch size. Defaults to 12.

    optimizer: str
        Presumably the optimizer name. Defaults to 'adamw'.

    max_lr: float
        Presumably the peak learning rate. Defaults to 5e-5.

    num_gpus: int
        Presumably the number of GPUs to use. Defaults to
        ``torch.cuda.device_count()``.

    max_input_size: int
        Presumably the maximum input length in tokens. Defaults to 2 ** 11.

    lr_schedule_fn: str
        Presumably the learning-rate schedule name. Defaults to "linear".

    warmup_steps: int
        Presumably the number of learning-rate warmup steps. Defaults to 500.

    emb_dir: path
        Presumably the directory for extracted embeddings.
        Defaults to 'embeddings'.

    overwrite: bool
        Presumably whether existing outputs may be overwritten.
        Defaults to True.

    forward_batch: int
        Presumably the batch size for forward passes during embedding
        extraction. Defaults to 200.

    output_dir: path
        Presumably the output directory. Defaults to None.
    """


# Gene classifier
def classify_genes(
    gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
    genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
    corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
    model='.',
    max_input_size=2 ** 11,
    max_lr=5e-5,
    freeze_layers=4,
    num_gpus=1,
    num_proc=os.cpu_count(),
    geneformer_batch_size=9,
    epochs=1,
    filter_dataset=50_000,
    emb_extract=True,
    emb_layer=0,
    forward_batch=200,
    filter_data=None,
    inference=False,
    k_validate=True,
    model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
    skip_training=False,
    emb_dir='gene_emb',
    output_dir=None,
    max_cells=1000,
    num_cpus=os.cpu_count(),
):
    """Train, validate and/or run inference with a Geneformer gene classifier.

    NOTE(review): no statements follow this docstring in the visible file,
    so as shown a call performs no work and returns None -- confirm where
    the implementation lives.

    Defaults that call a function (``num_proc``, ``num_cpus``) are evaluated
    once, when the ``def`` statement runs, not on every call.

    Primary Parameters
    ------------------
    gene_info: path
        Path to the gene mapping table.

    corpus_30M: path
        Path to the Genecorpus-30M dataset.

    model: path
        Path to the pretrained Geneformer model. Defaults to '.'.

    genes: path
        Path to a CSV file containing different columns of genes together
        with the column labels.

    inference: bool
        Whether the model should be used to run inference. If False, the
        model trains on labeled data instead. Defaults to False.

    k_validate: bool
        Whether to run k-fold validation rather than a regular
        train/evaluate pass. Defaults to True.

    skip_training: bool
        Whether to skip the training portion. Defaults to False.

    emb_extract: bool
        Whether to extract embeddings for a given gene (work in progress).
        Defaults to True.

    Customization Parameters
    ------------------------
    freeze_layers: int
        Number of model layers to freeze. Defaults to 4. Earlier docs said
        this leaves 2 layers trainable, which presumes a 6-layer model --
        confirm.

    filter_dataset: int
        Number of cells to filter from the 30M dataset (presumably the number
        of cells kept -- confirm). Defaults to 50_000.

    emb_layer: int
        Layer that embeddings are extracted from. Defaults to 0.

    filter_data: str or list
        Filters embeddings down to a category. Defaults to None.

    Other Parameters
    ----------------
    NOTE(review): these were previously undocumented; the descriptions are
    inferred from the parameter names and defaults -- confirm against the
    implementation.

    max_input_size: int
        Presumably the maximum input length in tokens. Defaults to 2 ** 11.

    max_lr: float
        Presumably the peak learning rate. Defaults to 5e-5.

    num_gpus: int
        Presumably the number of GPUs to use. Defaults to 1 (unlike
        ``finetune_cells``, which defaults to ``torch.cuda.device_count()``).

    num_proc, num_cpus: int or None
        Both default to ``os.cpu_count()`` (which can be None).
        NOTE(review): these look redundant -- confirm which one the
        implementation uses.

    geneformer_batch_size: int
        Presumably the training batch size. Defaults to 9.

    epochs: int
        Presumably the number of training epochs. Defaults to 1.

    forward_batch: int
        Presumably the batch size for forward passes during embedding
        extraction. Defaults to 200.

    model_location: path
        Presumably an existing fine-tuned model used when training is
        skipped or for inference. The default is a hard-coded, run-specific
        directory name; confirm it exists before relying on it.

    emb_dir: path
        Presumably the directory for extracted embeddings.
        Defaults to 'gene_emb'.

    output_dir: path
        Presumably the output directory. Defaults to None.

    max_cells: int
        Presumably the maximum number of cells used for embedding
        extraction. Defaults to 1000.
    """