hchen725 commited on
Commit
cec6545
1 Parent(s): 9154763

Update geneformer/perturber_utils.py

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +4 -2
geneformer/perturber_utils.py CHANGED
@@ -155,9 +155,11 @@ def quant_layers(model):
155
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
156
  return int(max(layer_nums)) + 1
157
 
158
- def get_model_embedding_dimensions(model):
159
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
160
 
 
 
161
 
162
  def get_model_input_size(model):
163
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
 
155
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
156
  return int(max(layer_nums)) + 1
157
 
158
+ def get_model_emb_dims(model):
159
+ return model.config.hidden_size
160
 
161
+ def get_model_input_size(model):
162
+ return model.config.max_position_embeddings
163
 
164
  def get_model_input_size(model):
165
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])