Jensen-holm committed on
Commit 9117cbd
1 Parent(s): de85694

getting rid of regression tab, thinking about having a custom data tab

where users can upload their own dataset and hyperparameter-tune on it, and making the classification tab the Classification example tab.
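A custom data tab along these lines could look roughly like the sketch below. This is a hypothetical illustration, not code from this commit: the `gr.File` upload handling, the `handle_upload` helper, and the slider ranges are assumptions; only the hyperparameters it exposes (epochs, hidden size, learning rate) mirror what app.py's classification tab already takes.

```python
# Hypothetical sketch of the "custom data" tab described in the commit message.
# gr.File handling, the slider ranges, and handle_upload are assumptions, not code from app.py.
import gradio as gr
import pandas as pd


def handle_upload(csv_path: str, epochs: int, hidden_size: int, learning_rate: float) -> str:
    # Read the uploaded CSV and treat the last column as the target.
    # A real implementation would hand X, y and the hyperparameters to the existing nn code.
    df = pd.read_csv(csv_path)
    X, y = df.iloc[:, :-1], df.iloc[:, -1]
    return f"loaded {X.shape[0]} rows x {X.shape[1]} features; target = {y.name!r}"


with gr.Blocks() as demo:
    with gr.Tab("Custom Data"):
        upload = gr.File(label="Upload a CSV dataset", type="filepath")
        epochs = gr.Slider(1, 10_000, value=1_000, step=1, label="epochs")
        hidden_size = gr.Slider(2, 128, value=16, step=1, label="hidden size")
        learning_rate = gr.Number(value=0.01, label="learning rate")
        train_btn = gr.Button("Train")
        status = gr.Label(label="status")
        train_btn.click(
            fn=handle_upload,
            inputs=[upload, epochs, hidden_size, learning_rate],
            outputs=[status],
        )
```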

Files changed (1)
  1. app.py +6 -9
app.py CHANGED
```diff
@@ -34,21 +34,21 @@ X_train, X_test, y_train, y_test = _preprocess_digits(seed=1)
 
 def classification(
     seed: int,
-    hidden_layer_activation_fn: str,
-    output_layer_activation_fn: str,
+    hidden_layer_activation_fn_str: str,
+    output_layer_activation_fn_str: str,
     loss_fn_str: str,
     epochs: int,
     hidden_size: int,
     batch_size: float,
     learning_rate: float,
 ) -> tuple[gr.Plot, gr.Plot, gr.Label]:
-    assert hidden_layer_activation_fn in nn.ACTIVATIONS
-    assert output_layer_activation_fn in nn.ACTIVATIONS
+    assert hidden_layer_activation_fn_str in nn.ACTIVATIONS
+    assert output_layer_activation_fn_str in nn.ACTIVATIONS
     assert loss_fn_str in nn.LOSSES
 
     loss_fn: nn.Loss = nn.LOSSES[loss_fn_str]
-    h_act_fn: nn.Activation = nn.ACTIVATIONS[hidden_layer_activation_fn]
-    o_act_fn: nn.Activation = nn.ACTIVATIONS[output_layer_activation_fn]
+    h_act_fn: nn.Activation = nn.ACTIVATIONS[hidden_layer_activation_fn_str]
+    o_act_fn: nn.Activation = nn.ACTIVATIONS[output_layer_activation_fn_str]
 
     nn_classifier = nn.NN(
         epochs=epochs,
@@ -164,7 +164,4 @@ if __name__ == "__main__":
             outputs=plt_outputs + label_output,
         )
 
-        with gr.Tab("Regression"):
-            gr.Markdown("### Coming Soon")
-
     interface.launch(show_error=True)
```
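The rename only adds a `_str` suffix so the strings coming from the Gradio dropdowns are clearly distinguished from the `nn.Activation` callables they resolve to via the `nn.ACTIVATIONS` registry. Below is a minimal sketch of that string-to-callable lookup pattern; the concrete keys ("relu", "sigmoid") and the registry contents are assumptions, since only the dict names appear in the diff.

```python
# Minimal sketch of the string -> callable registry pattern used in classification().
# The keys and activation functions here are assumptions for illustration only.
from typing import Callable

import numpy as np

Activation = Callable[[np.ndarray], np.ndarray]

ACTIVATIONS: dict[str, Activation] = {
    "relu": lambda x: np.maximum(0, x),
    "sigmoid": lambda x: 1.0 / (1.0 + np.exp(-x)),
}


def resolve_activation(activation_fn_str: str) -> Activation:
    # Validate the UI-supplied string before looking it up, mirroring the asserts in app.py.
    assert activation_fn_str in ACTIVATIONS
    return ACTIVATIONS[activation_fn_str]


h_act_fn = resolve_activation("relu")      # hidden layer
o_act_fn = resolve_activation("sigmoid")   # output layer
```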