a tiny bug fix missing default_training_args

#320
Files changed (1) hide show
  1. geneformer/classifier_utils.py +4 -0
geneformer/classifier_utils.py CHANGED
@@ -387,6 +387,10 @@ def get_default_train_args(model, classifier, data, output_dir):
387
  "per_device_train_batch_size": batch_size,
388
  "per_device_eval_batch_size": batch_size,
389
  }
 
 
 
 
390
 
391
  training_args = {
392
  "num_train_epochs": epochs,
 
387
  "per_device_train_batch_size": batch_size,
388
  "per_device_eval_batch_size": batch_size,
389
  }
390
+ else:
391
+ default_training_args = {
392
+ "per_device_train_batch_size": batch_size,
393
+ }
394
 
395
  training_args = {
396
  "num_train_epochs": epochs,