Spaces:
Sleeping
Sleeping
Jensen-holm
commited on
Commit
·
88a3c01
1
Parent(s):
ba1e2db
fixed dimensions issue with the iris dataset
Browse files- neural_network/backprop.py +3 -2
- neural_network/main.py +3 -6
neural_network/backprop.py
CHANGED
@@ -48,8 +48,9 @@ def bp(
|
|
48 |
error * model.func_prime(y_hat),
|
49 |
)
|
50 |
db2 = np.sum(error * model.func_prime(y_hat), axis=0)
|
51 |
-
db1 = np.sum(
|
52 |
-
|
|
|
53 |
|
54 |
# update weights & biases using gradient descent.
|
55 |
# this is -= and not += because if the gradient descent
|
|
|
48 |
error * model.func_prime(y_hat),
|
49 |
)
|
50 |
db2 = np.sum(error * model.func_prime(y_hat), axis=0)
|
51 |
+
db1 = np.sum(
|
52 |
+
np.dot(error * model.func_prime(y_hat), model.w2.T) * model.func_prime(node1), axis=0,
|
53 |
+
)
|
54 |
|
55 |
# update weights & biases using gradient descent.
|
56 |
# this is -= and not += because if the gradient descent
|
neural_network/main.py
CHANGED
@@ -5,10 +5,7 @@ from neural_network.opts import activation
|
|
5 |
from neural_network.backprop import bp
|
6 |
|
7 |
|
8 |
-
def init(
|
9 |
-
X: np.array,
|
10 |
-
hidden_size: int
|
11 |
-
) -> dict:
|
12 |
"""
|
13 |
returns a dictionary containing randomly initialized
|
14 |
weights and biases to start off the neural_network
|
@@ -16,8 +13,8 @@ def init(
|
|
16 |
return {
|
17 |
"w1": np.random.randn(X.shape[1], hidden_size),
|
18 |
"b1": np.zeros((1, hidden_size)),
|
19 |
-
"w2": np.random.randn(hidden_size,
|
20 |
-
"b2": np.zeros((1,
|
21 |
}
|
22 |
|
23 |
|
|
|
5 |
from neural_network.backprop import bp
|
6 |
|
7 |
|
8 |
+
def init(X: np.array, hidden_size: int) -> dict:
|
|
|
|
|
|
|
9 |
"""
|
10 |
returns a dictionary containing randomly initialized
|
11 |
weights and biases to start off the neural_network
|
|
|
13 |
return {
|
14 |
"w1": np.random.randn(X.shape[1], hidden_size),
|
15 |
"b1": np.zeros((1, hidden_size)),
|
16 |
+
"w2": np.random.randn(hidden_size, 3), # Output layer has 3 neurons
|
17 |
+
"b2": np.zeros((1, 3)), # Output layer has 3 neurons
|
18 |
}
|
19 |
|
20 |
|