Jensen Holm commited on
Commit
498c4e0
1 Parent(s): d7ea050

making the example code cleaner

Browse files
Files changed (1) hide show
  1. example.py +12 -11
example.py CHANGED
@@ -1,7 +1,7 @@
1
  from sklearn import datasets
2
  from sklearn.preprocessing import OneHotEncoder
3
  from sklearn.model_selection import train_test_split
4
- from sklearn.metrics import accuracy_score, precision_score, recall_score
5
  import numpy as np
6
  from numpyneuron import (
7
  NN,
@@ -14,7 +14,7 @@ from numpyneuron import (
14
  RANDOM_SEED = 2
15
 
16
 
17
- def _preprocess_digits(
18
  seed: int,
19
  ) -> tuple[np.ndarray, ...]:
20
  digits = datasets.load_digits(as_frame=False)
@@ -30,9 +30,10 @@ def _preprocess_digits(
30
  return X_train, X_test, y_train, y_test
31
 
32
 
33
- def train_nn_classifier() -> None:
34
- X_train, X_test, y_train, y_test = _preprocess_digits(seed=RANDOM_SEED)
35
-
 
36
  nn_classifier = NN(
37
  epochs=2_000,
38
  hidden_size=16,
@@ -50,16 +51,16 @@ def train_nn_classifier() -> None:
50
  X_train=X_train,
51
  y_train=y_train,
52
  )
 
 
53
 
54
- pred = nn_classifier.predict(X_test=X_test)
 
 
55
 
 
56
  pred = np.argmax(pred, axis=1)
57
  y_test = np.argmax(y_test, axis=1)
58
 
59
  accuracy = accuracy_score(y_true=y_test, y_pred=pred)
60
-
61
  print(f"accuracy on validation set: {accuracy:.4f}")
62
-
63
-
64
- if __name__ == "__main__":
65
- train_nn_classifier()
 
1
  from sklearn import datasets
2
  from sklearn.preprocessing import OneHotEncoder
3
  from sklearn.model_selection import train_test_split
4
+ from sklearn.metrics import accuracy_score
5
  import numpy as np
6
  from numpyneuron import (
7
  NN,
 
14
  RANDOM_SEED = 2
15
 
16
 
17
+ def preprocess_digits(
18
  seed: int,
19
  ) -> tuple[np.ndarray, ...]:
20
  digits = datasets.load_digits(as_frame=False)
 
30
  return X_train, X_test, y_train, y_test
31
 
32
 
33
+ def train_nn_classifier(
34
+ X_train: np.ndarray,
35
+ y_train: np.ndarray,
36
+ ) -> NN:
37
  nn_classifier = NN(
38
  epochs=2_000,
39
  hidden_size=16,
 
51
  X_train=X_train,
52
  y_train=y_train,
53
  )
54
+ return nn_classifier
55
+
56
 
57
+ if __name__ == "__main__":
58
+ X_train, X_test, y_train, y_test = preprocess_digits(seed=RANDOM_SEED)
59
+ classifier = train_nn_classifier(X_train, y_train)
60
 
61
+ pred = classifier.predict(X_test)
62
  pred = np.argmax(pred, axis=1)
63
  y_test = np.argmax(y_test, axis=1)
64
 
65
  accuracy = accuracy_score(y_true=y_test, y_pred=pred)
 
66
  print(f"accuracy on validation set: {accuracy:.4f}")