cmpatino commited on
Commit
39ac7ff
·
1 Parent(s): 794f08d

Add ROC-AUC score for each feature

Browse files
Files changed (2) hide show
  1. app.py +18 -3
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
 
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
 
 
5
  from datasets import load_dataset
6
 
7
  import histos
@@ -16,9 +18,22 @@ def get_plot(features, n_bins):
16
  plotting_df = dataset_df.copy()
17
  if len(features) == 1:
18
  fig, ax = plt.subplots()
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  values = [
20
- plotting_df[plotting_df["target"] == "spin-ON"][features[0]],
21
- plotting_df[plotting_df["target"] == "spin-OFF"][features[0]],
22
  ]
23
  labels = ["spin-ON", "spin-OFF"]
24
  fig = histos.ratio_hist(
@@ -27,7 +42,7 @@ def get_plot(features, n_bins):
27
  reference_label=labels[1],
28
  n_bins=n_bins,
29
  hist_range=None,
30
- title=features[0],
31
  )
32
  return fig
33
  if len(features) == 2:
 
1
  import gradio as gr
2
+ import numpy as np
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
 
6
+ from sklearn import metrics
7
  from datasets import load_dataset
8
 
9
  import histos
 
18
  plotting_df = dataset_df.copy()
19
  if len(features) == 1:
20
  fig, ax = plt.subplots()
21
+ pos_samples = plotting_df[plotting_df["target"] == "spin-ON"][features[0]]
22
+ neg_samples = plotting_df[plotting_df["target"] == "spin-OFF"][features[0]]
23
+ y_score = np.concatenate([pos_samples, neg_samples], axis=0)
24
+ if pos_samples.mean() >= neg_samples.mean():
25
+ y_true = np.concatenate(
26
+ [np.ones_like(pos_samples), np.zeros_like(neg_samples)], axis=0
27
+ )
28
+ roc_auc_score = metrics.roc_auc_score(y_true, y_score)
29
+ else:
30
+ y_true = np.concatenate(
31
+ [np.zeros_like(pos_samples), np.ones_like(neg_samples)], axis=0
32
+ )
33
+ roc_auc_score = metrics.roc_auc_score(y_true, y_score)
34
  values = [
35
+ pos_samples,
36
+ neg_samples,
37
  ]
38
  labels = ["spin-ON", "spin-OFF"]
39
  fig = histos.ratio_hist(
 
42
  reference_label=labels[1],
43
  n_bins=n_bins,
44
  hist_range=None,
45
+ title=f"{features[0]} (ROC AUC: {roc_auc_score:.3f})",
46
  )
47
  return fig
48
  if len(features) == 2:
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  matplotlib==3.7.1
 
2
  seaborn==0.12.2
 
1
  matplotlib==3.7.1
2
+ scikit-learn==1.2.2
3
  seaborn==0.12.2