EduardoPacheco commited on
Commit
bf23c3b
1 Parent(s): 9c6a64c

Event listener

Browse files
Files changed (1) hide show
  1. app.py +34 -10
app.py CHANGED
@@ -17,7 +17,7 @@ def plot_validation_curve(x: np.array, ys: list[np.array], yerros: list[np.array
17
  fig.add_trace(
18
  go.Scatter(
19
  x=x,
20
- y=y,
21
  name=name,
22
  line_color=color
23
  )
@@ -39,23 +39,37 @@ def plot_validation_curve(x: np.array, ys: list[np.array], yerros: list[np.array
39
  if log_x:
40
  fig.update_xaxes(type="log")
41
 
42
- fig.update_layout(title=title, xaxis_title="gamma", yaxis_title="Accuracy")
 
 
 
 
 
43
 
44
  return fig
45
 
46
 
47
 
48
- def app_fn(n_points: int):
49
  X, y = load_digits(return_X_y=True)
50
  subset_mask = np.isin(y, [1, 2]) # binary classification: 1 vs 2
51
  X, y = X[subset_mask], y[subset_mask]
52
 
53
- param_range = np.logspace(-6, -1, n_points)
 
 
 
 
 
 
 
 
 
54
  train_scores, test_scores = validation_curve(
55
  SVC(),
56
  X,
57
  y,
58
- param_name="gamma",
59
  param_range=param_range,
60
  scoring="accuracy",
61
  n_jobs=-1,
@@ -72,7 +86,8 @@ def app_fn(n_points: int):
72
  [train_scores_std, test_scores_std],
73
  ["Training score", "Cross-validation score"],
74
  ["orange", "navy"],
75
- title="Validation Curve with SVM for Gamma Hyperparameter"
 
76
  )
77
 
78
  return fig
@@ -90,12 +105,21 @@ with gr.Blocks(title=title) as demo:
90
  [Original Example](https://scikit-learn.org/stable/auto_examples/model_selection/plot_validation_curve.html#sphx-glr-auto-examples-model-selection-plot-validation-curve-py)
91
  """
92
  )
 
 
 
 
93
 
94
- n_points = gr.inputs.Slider(5, 100, 5, 5,label="Number of points")
95
- btn = gr.Button("Run")
96
  fig = gr.Plot(label="Validation Curve")
97
 
98
- btn.click(fn=app_fn, inputs=[n_points], outputs=[fig])
99
- demo.load(fn=app_fn, inputs=[n_points], outputs=[fig])
 
 
 
 
 
 
 
100
 
101
  demo.launch()
 
17
  fig.add_trace(
18
  go.Scatter(
19
  x=x,
20
+ y=np.round(y, 3),
21
  name=name,
22
  line_color=color
23
  )
 
39
  if log_x:
40
  fig.update_xaxes(type="log")
41
 
42
+ fig.update_layout(
43
+ title=title,
44
+ xaxis_title="Hyperparameter",
45
+ yaxis_title="Accuracy",
46
+ hovermode="x unified",
47
+ )
48
 
49
  return fig
50
 
51
 
52
 
53
+ def app_fn(n_points: int, param_name: str):
54
  X, y = load_digits(return_X_y=True)
55
  subset_mask = np.isin(y, [1, 2]) # binary classification: 1 vs 2
56
  X, y = X[subset_mask], y[subset_mask]
57
 
58
+ if param_name=="gamma":
59
+ param_range = np.logspace(-6, -1, n_points)
60
+ log_x = True
61
+ elif param_name=="C":
62
+ param_range = np.logspace(-2, 0, n_points)
63
+ log_x = True
64
+ elif param_name=="kernel":
65
+ param_range = np.array(["rbf", "linear", "poly", "sigmoid"])
66
+ log_x = False
67
+
68
  train_scores, test_scores = validation_curve(
69
  SVC(),
70
  X,
71
  y,
72
+ param_name=param_name,
73
  param_range=param_range,
74
  scoring="accuracy",
75
  n_jobs=-1,
 
86
  [train_scores_std, test_scores_std],
87
  ["Training score", "Cross-validation score"],
88
  ["orange", "navy"],
89
+ title=f"Validation Curve with SVM for {param_name} Hyperparameter",
90
+ log_x=log_x
91
  )
92
 
93
  return fig
 
105
  [Original Example](https://scikit-learn.org/stable/auto_examples/model_selection/plot_validation_curve.html#sphx-glr-auto-examples-model-selection-plot-validation-curve-py)
106
  """
107
  )
108
+ with gr.Row():
109
+ n_points = gr.inputs.Slider(5, 100, 5, 5,label="Number of points")
110
+ param_name = gr.inputs.Dropdown(["gamma", "C", "kernel"], label="Hyperparameter", default="gamma")
111
+
112
 
 
 
113
  fig = gr.Plot(label="Validation Curve")
114
 
115
+ n_points.release(fn=app_fn, inputs=[n_points, param_name], outputs=[fig])
116
+ param_name.change(fn=app_fn, inputs=[n_points, param_name], outputs=[fig])
117
+ # C.change(fn=app_fn, inputs=[n_points, param_name, C, gamma, kernel, degree], outputs=[fig])
118
+ # gamma.change(fn=app_fn, inputs=[n_points, param_name, C, gamma, kernel, degree], outputs=[fig])
119
+ # kernel.change(fn=app_fn, inputs=[n_points, param_name, C, gamma, kernel, degree], outputs=[fig])
120
+ # degree.change(fn=app_fn, inputs=[n_points, param_name, C, gamma, kernel, degree], outputs=[fig])
121
+
122
+
123
+ demo.load(fn=app_fn, inputs=[n_points, param_name], outputs=[fig])
124
 
125
  demo.launch()