Christina Theodoris committed
Commit 0d675a3
1 Parent(s): 316d817

add load model for train and fix validate anchor gene error

geneformer/in_silico_perturber.py CHANGED

@@ -263,7 +263,7 @@ class InSilicoPerturber:
                 "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
             )
             raise
-        if (self.combos > 0) and (self.anchor_token is None):
+        if (self.combos > 0) and (self.anchor_gene is None):
             logger.error(
                 "Combination perturbation without anchor gene is currently under development. "
                 "Currently, must provide anchor gene for combination perturbation."
@@ -416,7 +416,9 @@ class InSilicoPerturber:
         )
 
         ### load model and define parameters ###
-        model = pu.load_model(self.model_type, self.num_classes, model_directory)
+        model = pu.load_model(
+            self.model_type, self.num_classes, model_directory, mode="eval"
+        )
         self.max_len = pu.get_model_input_size(model)
         layer_to_quant = pu.quant_layers(model) + self.emb_layer
 
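Note: the guard above now checks the user-facing anchor_gene argument instead of anchor_token, matching the commit's "fix validate anchor gene error". A minimal standalone sketch of the same check, using a hypothetical helper name (validate_combos) and a ValueError in place of the module's bare raise:

import logging

logger = logging.getLogger(__name__)


def validate_combos(combos, anchor_gene):
    # Combination perturbation currently requires an anchor gene.
    if (combos > 0) and (anchor_gene is None):
        logger.error(
            "Combination perturbation without anchor gene is currently under development. "
            "Currently, must provide anchor gene for combination perturbation."
        )
        raise ValueError("anchor_gene is required when combos > 0")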
 
geneformer/perturber_utils.py CHANGED

@@ -108,28 +108,36 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
 
 
 # load model to GPU
-def load_model(model_type, num_classes, model_directory):
+def load_model(model_type, num_classes, model_directory, mode):
+    if mode == "eval":
+        output_hidden_states = True
+    elif mode == "train":
+        output_hidden_states = False
+
     if model_type == "Pretrained":
         model = BertForMaskedLM.from_pretrained(
-            model_directory, output_hidden_states=True, output_attentions=False
+            model_directory,
+            output_hidden_states=output_hidden_states,
+            output_attentions=False,
         )
     elif model_type == "GeneClassifier":
         model = BertForTokenClassification.from_pretrained(
             model_directory,
             num_labels=num_classes,
-            output_hidden_states=True,
+            output_hidden_states=output_hidden_states,
             output_attentions=False,
         )
     elif model_type == "CellClassifier":
         model = BertForSequenceClassification.from_pretrained(
             model_directory,
             num_labels=num_classes,
-            output_hidden_states=True,
+            output_hidden_states=output_hidden_states,
             output_attentions=False,
         )
-    # put the model in eval mode for fwd pass
-    model.eval()
-    model = model.to("cuda:0")
+    # if eval mode, put the model in eval mode for fwd pass
+    if mode == "eval":
+        model.eval()
+    model = model.to("cuda")
     return model
 
 
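A minimal usage sketch of the updated load_model signature, assuming it is imported from geneformer.perturber_utils as shown in the diff above; the model directory path is a placeholder, and a CUDA-capable GPU is assumed since the function moves the model to "cuda". In "eval" mode hidden states are enabled and model.eval() is applied; in "train" mode hidden states are disabled so the returned model can be fine-tuned.

from geneformer.perturber_utils import load_model

# Placeholder path; substitute a real pretrained or fine-tuned checkpoint directory.
model_directory = "/path/to/Geneformer_checkpoint"

# Eval mode (as used by InSilicoPerturber): hidden states on, model.eval() applied.
# num_classes is unused for the "Pretrained" masked LM, so 0 is passed here.
eval_model = load_model("Pretrained", 0, model_directory, mode="eval")

# Train mode: hidden states off, model left trainable for fine-tuning,
# e.g. a cell-state classifier with 3 output classes.
train_model = load_model("CellClassifier", 3, model_directory, mode="train")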