Spaces:
Sleeping
Sleeping
Jensen Holm
commited on
Commit
•
498c4e0
1
Parent(s):
d7ea050
making the example code cleaner
Browse files- example.py +12 -11
example.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from sklearn import datasets
|
2 |
from sklearn.preprocessing import OneHotEncoder
|
3 |
from sklearn.model_selection import train_test_split
|
4 |
-
from sklearn.metrics import accuracy_score
|
5 |
import numpy as np
|
6 |
from numpyneuron import (
|
7 |
NN,
|
@@ -14,7 +14,7 @@ from numpyneuron import (
|
|
14 |
RANDOM_SEED = 2
|
15 |
|
16 |
|
17 |
-
def
|
18 |
seed: int,
|
19 |
) -> tuple[np.ndarray, ...]:
|
20 |
digits = datasets.load_digits(as_frame=False)
|
@@ -30,9 +30,10 @@ def _preprocess_digits(
|
|
30 |
return X_train, X_test, y_train, y_test
|
31 |
|
32 |
|
33 |
-
def train_nn_classifier(
|
34 |
-
X_train
|
35 |
-
|
|
|
36 |
nn_classifier = NN(
|
37 |
epochs=2_000,
|
38 |
hidden_size=16,
|
@@ -50,16 +51,16 @@ def train_nn_classifier() -> None:
|
|
50 |
X_train=X_train,
|
51 |
y_train=y_train,
|
52 |
)
|
|
|
|
|
53 |
|
54 |
-
|
|
|
|
|
55 |
|
|
|
56 |
pred = np.argmax(pred, axis=1)
|
57 |
y_test = np.argmax(y_test, axis=1)
|
58 |
|
59 |
accuracy = accuracy_score(y_true=y_test, y_pred=pred)
|
60 |
-
|
61 |
print(f"accuracy on validation set: {accuracy:.4f}")
|
62 |
-
|
63 |
-
|
64 |
-
if __name__ == "__main__":
|
65 |
-
train_nn_classifier()
|
|
|
1 |
from sklearn import datasets
|
2 |
from sklearn.preprocessing import OneHotEncoder
|
3 |
from sklearn.model_selection import train_test_split
|
4 |
+
from sklearn.metrics import accuracy_score
|
5 |
import numpy as np
|
6 |
from numpyneuron import (
|
7 |
NN,
|
|
|
14 |
RANDOM_SEED = 2
|
15 |
|
16 |
|
17 |
+
def preprocess_digits(
|
18 |
seed: int,
|
19 |
) -> tuple[np.ndarray, ...]:
|
20 |
digits = datasets.load_digits(as_frame=False)
|
|
|
30 |
return X_train, X_test, y_train, y_test
|
31 |
|
32 |
|
33 |
+
def train_nn_classifier(
|
34 |
+
X_train: np.ndarray,
|
35 |
+
y_train: np.ndarray,
|
36 |
+
) -> NN:
|
37 |
nn_classifier = NN(
|
38 |
epochs=2_000,
|
39 |
hidden_size=16,
|
|
|
51 |
X_train=X_train,
|
52 |
y_train=y_train,
|
53 |
)
|
54 |
+
return nn_classifier
|
55 |
+
|
56 |
|
57 |
+
if __name__ == "__main__":
|
58 |
+
X_train, X_test, y_train, y_test = preprocess_digits(seed=RANDOM_SEED)
|
59 |
+
classifier = train_nn_classifier(X_train, y_train)
|
60 |
|
61 |
+
pred = classifier.predict(X_test)
|
62 |
pred = np.argmax(pred, axis=1)
|
63 |
y_test = np.argmax(y_test, axis=1)
|
64 |
|
65 |
accuracy = accuracy_score(y_true=y_test, y_pred=pred)
|
|
|
66 |
print(f"accuracy on validation set: {accuracy:.4f}")
|
|
|
|
|
|
|
|