Christina Theodoris commited on
Commit
eeba323
1 Parent(s): f75f5ac

update examples for predict_eval and handle roc for 2 cell classes

Browse files
examples/cell_classification.ipynb CHANGED
@@ -266,8 +266,7 @@
266
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
267
  " output_directory=output_dir,\n",
268
  " output_prefix=output_prefix,\n",
269
- " split_id_dict=train_valid_id_split_dict,\n",
270
- " predict=True)"
271
  ]
272
  },
273
  {
 
266
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
267
  " output_directory=output_dir,\n",
268
  " output_prefix=output_prefix,\n",
269
+ " split_id_dict=train_valid_id_split_dict)"
 
270
  ]
271
  },
272
  {
geneformer/classifier.py CHANGED
@@ -30,7 +30,7 @@ Geneformer classifier.
30
  ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
31
  ... output_directory="path/to/output_directory",
32
  ... output_prefix="output_prefix",
33
- ... predict=True)
34
  >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
35
  ... output_directory="path/to/output_directory",
36
  ... output_prefix="output_prefix",
@@ -308,7 +308,7 @@ class Classifier:
308
  output_directory,
309
  output_prefix,
310
  split_id_dict=None,
311
- test_size=0,
312
  attr_to_split=None,
313
  attr_to_balance=None,
314
  max_trials=100,
@@ -417,27 +417,48 @@ class Classifier:
417
  data_dict["test"].save_to_disk(test_data_output_path)
418
  elif (test_size is not None) and (self.classifier == "cell"):
419
  if 1 > test_size > 0:
420
- data_dict, balance_df = cu.balance_attr_splits(
421
- data,
422
- attr_to_split,
423
- attr_to_balance,
424
- test_size,
425
- max_trials,
426
- pval_threshold,
427
- self.cell_state_dict["state_key"],
428
- self.nproc,
429
- )
430
- balance_df.to_csv(
431
- f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
432
- )
433
- train_data_output_path = (
434
- Path(output_directory) / f"{output_prefix}_labeled_train"
435
- ).with_suffix(".dataset")
436
- test_data_output_path = (
437
- Path(output_directory) / f"{output_prefix}_labeled_test"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  ).with_suffix(".dataset")
439
- data_dict["train"].save_to_disk(train_data_output_path)
440
- data_dict["test"].save_to_disk(test_data_output_path)
441
  else:
442
  data_output_path = (
443
  Path(output_directory) / f"{output_prefix}_labeled"
@@ -1012,7 +1033,7 @@ class Classifier:
1012
  model = pu.load_model(model_type, num_classes, model_directory, "eval")
1013
 
1014
  # evaluate the model
1015
- results = self.evaluate_model(
1016
  model,
1017
  num_classes,
1018
  id_class_dict,
@@ -1023,24 +1044,21 @@ class Classifier:
1023
  )
1024
 
1025
  all_conf_mat_df = pd.DataFrame(
1026
- results["conf_mat"],
1027
  columns=id_class_dict.values(),
1028
  index=id_class_dict.values(),
1029
  )
1030
  all_metrics = {
1031
  "conf_matrix": all_conf_mat_df,
1032
- "macro_f1": results["macro_f1"],
1033
- "acc": results["acc"],
1034
  }
1035
  all_roc_metrics = None # roc metrics not reported for multiclass
 
1036
  if num_classes == 2:
1037
  mean_fpr = np.linspace(0, 1, 100)
1038
- all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
1039
- all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
1040
- all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
1041
- mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
1042
- all_tpr, all_roc_auc, all_tpr_wt
1043
- )
1044
  all_roc_metrics = {
1045
  "mean_tpr": mean_tpr,
1046
  "mean_fpr": mean_fpr,
@@ -1137,7 +1155,7 @@ class Classifier:
1137
 
1138
  predictions_file : path
1139
  | Path of model predictions output to plot
1140
- | (saved output from self.validate if predict=True)
1141
  | (or saved output from self.evaluate_saved_model)
1142
  id_class_dict_file : Path
1143
  | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
@@ -1173,7 +1191,7 @@ class Classifier:
1173
  predictions_logits = np.array(predictions["predictions"])
1174
  true_ids = predictions["label_ids"]
1175
  else:
1176
- # format is output from self.validate if predict=True
1177
  predictions_logits = predictions.predictions
1178
  true_ids = predictions.label_ids
1179
 
 
30
  ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
31
  ... output_directory="path/to/output_directory",
32
  ... output_prefix="output_prefix",
33
+ ... predict_eval=True)
34
  >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
35
  ... output_directory="path/to/output_directory",
36
  ... output_prefix="output_prefix",
 
308
  output_directory,
309
  output_prefix,
310
  split_id_dict=None,
311
+ test_size=None,
312
  attr_to_split=None,
313
  attr_to_balance=None,
314
  max_trials=100,
 
417
  data_dict["test"].save_to_disk(test_data_output_path)
418
  elif (test_size is not None) and (self.classifier == "cell"):
419
  if 1 > test_size > 0:
420
+ if attr_to_split is None:
421
+ data_dict = data.train_test_split(
422
+ test_size=test_size,
423
+ stratify_by_column=self.stratify_splits_col,
424
+ seed=42,
425
+ )
426
+ train_data_output_path = (
427
+ Path(output_directory) / f"{output_prefix}_labeled_train"
428
+ ).with_suffix(".dataset")
429
+ test_data_output_path = (
430
+ Path(output_directory) / f"{output_prefix}_labeled_test"
431
+ ).with_suffix(".dataset")
432
+ data_dict["train"].save_to_disk(train_data_output_path)
433
+ data_dict["test"].save_to_disk(test_data_output_path)
434
+ else:
435
+ data_dict, balance_df = cu.balance_attr_splits(
436
+ data,
437
+ attr_to_split,
438
+ attr_to_balance,
439
+ test_size,
440
+ max_trials,
441
+ pval_threshold,
442
+ self.cell_state_dict["state_key"],
443
+ self.nproc,
444
+ )
445
+ balance_df.to_csv(
446
+ f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
447
+ )
448
+ train_data_output_path = (
449
+ Path(output_directory) / f"{output_prefix}_labeled_train"
450
+ ).with_suffix(".dataset")
451
+ test_data_output_path = (
452
+ Path(output_directory) / f"{output_prefix}_labeled_test"
453
+ ).with_suffix(".dataset")
454
+ data_dict["train"].save_to_disk(train_data_output_path)
455
+ data_dict["test"].save_to_disk(test_data_output_path)
456
+ else:
457
+ data_output_path = (
458
+ Path(output_directory) / f"{output_prefix}_labeled"
459
  ).with_suffix(".dataset")
460
+ data.save_to_disk(data_output_path)
461
+ print(data_output_path)
462
  else:
463
  data_output_path = (
464
  Path(output_directory) / f"{output_prefix}_labeled"
 
1033
  model = pu.load_model(model_type, num_classes, model_directory, "eval")
1034
 
1035
  # evaluate the model
1036
+ result = self.evaluate_model(
1037
  model,
1038
  num_classes,
1039
  id_class_dict,
 
1044
  )
1045
 
1046
  all_conf_mat_df = pd.DataFrame(
1047
+ result["conf_mat"],
1048
  columns=id_class_dict.values(),
1049
  index=id_class_dict.values(),
1050
  )
1051
  all_metrics = {
1052
  "conf_matrix": all_conf_mat_df,
1053
+ "macro_f1": result["macro_f1"],
1054
+ "acc": result["acc"],
1055
  }
1056
  all_roc_metrics = None # roc metrics not reported for multiclass
1057
+
1058
  if num_classes == 2:
1059
  mean_fpr = np.linspace(0, 1, 100)
1060
+ mean_tpr = result["roc_metrics"]["interp_tpr"]
1061
+ all_roc_auc = result["roc_metrics"]["auc"]
 
 
 
 
1062
  all_roc_metrics = {
1063
  "mean_tpr": mean_tpr,
1064
  "mean_fpr": mean_fpr,
 
1155
 
1156
  predictions_file : path
1157
  | Path of model predictions output to plot
1158
+ | (saved output from self.validate if predict_eval=True)
1159
  | (or saved output from self.evaluate_saved_model)
1160
  id_class_dict_file : Path
1161
  | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
 
1191
  predictions_logits = np.array(predictions["predictions"])
1192
  true_ids = predictions["label_ids"]
1193
  else:
1194
+ # format is output from self.validate if predict_eval=True
1195
  predictions_logits = predictions.predictions
1196
  true_ids = predictions.label_ids
1197
 
geneformer/evaluation_utils.py CHANGED
@@ -201,10 +201,10 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
201
  plt.ylabel("True Positive Rate")
202
  plt.title(title)
203
  plt.legend(loc="lower right")
204
- plt.show()
205
 
206
  output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
207
  plt.savefig(output_file, bbox_inches="tight")
 
208
 
209
 
210
  # plot confusion matrix
 
201
  plt.ylabel("True Positive Rate")
202
  plt.title(title)
203
  plt.legend(loc="lower right")
 
204
 
205
  output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
206
  plt.savefig(output_file, bbox_inches="tight")
207
+ plt.show()
208
 
209
 
210
  # plot confusion matrix