Spaces:
Sleeping
Sleeping
Commit
·
0a04cd7
1
Parent(s):
59ae052
mapping activation arg to the actual function
Browse files- nn/backprop.go +8 -0
- nn/main.go +27 -35
- nn/split.go +6 -4
- nn/subset.go +1 -1
- nn/train.go +13 -0
- server.go +0 -4
nn/backprop.go
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
package nn
|
2 |
+
|
3 |
+
func (nn *NN) Backprop() {
|
4 |
+
|
5 |
+
for i := 0; i < nn.Epochs; i++ {
|
6 |
+
}
|
7 |
+
|
8 |
+
}
|
nn/main.go
CHANGED
@@ -10,24 +10,27 @@ import (
|
|
10 |
)
|
11 |
|
12 |
type NN struct {
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
31 |
}
|
32 |
|
33 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
@@ -37,24 +40,13 @@ func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
37 |
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
38 |
}
|
39 |
df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
|
|
|
|
|
40 |
newNN.Df = &df
|
|
|
41 |
return newNN, nil
|
42 |
}
|
43 |
|
44 |
-
func (nn *NN) Train() {
|
45 |
-
// train test split the data
|
46 |
-
_, _, _, _ = nn.trainTestSplit()
|
47 |
-
|
48 |
-
nn.InitWnB()
|
49 |
-
|
50 |
-
// iterate n times where n = nn.Epochs
|
51 |
-
// use backprop algorithm on each iteration
|
52 |
-
// to fit the model to the data
|
53 |
-
for i := 0; i < nn.Epochs; i++ {
|
54 |
-
}
|
55 |
-
|
56 |
-
}
|
57 |
-
|
58 |
func (nn *NN) InitWnB() {
|
59 |
// randomly initialize weights and biases to start
|
60 |
inputSize := len(nn.Features)
|
@@ -89,8 +81,8 @@ func (nn *NN) InitWnB() {
|
|
89 |
bo[i] = rand.Float64() - 0.5
|
90 |
}
|
91 |
|
92 |
-
nn.Wh = wh
|
93 |
-
nn.Bh = bh
|
94 |
-
nn.Wo = wo
|
95 |
-
nn.Bo = bo
|
96 |
}
|
|
|
10 |
)
|
11 |
|
12 |
type NN struct {
|
13 |
+
// attributes set by request
|
14 |
+
CSVData string `json:"csv_data"`
|
15 |
+
Features []string `json:"features"`
|
16 |
+
Target string `json:"target"`
|
17 |
+
Epochs int `json:"epochs"`
|
18 |
+
HiddenSize int `json:"hidden_size"`
|
19 |
+
LearningRate float64 `json:"learning_rate"`
|
20 |
+
Activation string `json:"activation"`
|
21 |
+
TestSize float64 `json:"test_size"`
|
22 |
|
23 |
+
// attributes set after args above are parsed
|
24 |
+
ActivationFunc *func(float64) float64
|
25 |
+
Df *dataframe.DataFrame
|
26 |
+
XTrain *dataframe.DataFrame
|
27 |
+
YTrain *dataframe.DataFrame
|
28 |
+
XTest *dataframe.DataFrame
|
29 |
+
YTest *dataframe.DataFrame
|
30 |
+
Wh *[][]float64
|
31 |
+
Bh *[]float64
|
32 |
+
Wo *[][]float64
|
33 |
+
Bo *[]float64
|
34 |
}
|
35 |
|
36 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
|
40 |
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
41 |
}
|
42 |
df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
|
43 |
+
activation := ActivationMap[newNN.Activation]
|
44 |
+
|
45 |
newNN.Df = &df
|
46 |
+
newNN.ActivationFunc = &activation
|
47 |
return newNN, nil
|
48 |
}
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
func (nn *NN) InitWnB() {
|
51 |
// randomly initialize weights and biases to start
|
52 |
inputSize := len(nn.Features)
|
|
|
81 |
bo[i] = rand.Float64() - 0.5
|
82 |
}
|
83 |
|
84 |
+
nn.Wh = &wh
|
85 |
+
nn.Bh = &bh
|
86 |
+
nn.Wo = &wo
|
87 |
+
nn.Bo = &bo
|
88 |
}
|
nn/split.go
CHANGED
@@ -3,11 +3,9 @@ package nn
|
|
3 |
import (
|
4 |
"math"
|
5 |
"math/rand"
|
6 |
-
|
7 |
-
"github.com/go-gota/gota/dataframe"
|
8 |
)
|
9 |
|
10 |
-
func (nn *NN)
|
11 |
// now we split the data into training
|
12 |
// and testing based on user specified
|
13 |
// nn.TestSize.
|
@@ -36,5 +34,9 @@ func (nn *NN) trainTestSplit() (dataframe.DataFrame, dataframe.DataFrame, datafr
|
|
36 |
XTest := test.Select(nn.Features)
|
37 |
YTest := test.Select(nn.Target)
|
38 |
|
39 |
-
|
|
|
|
|
|
|
|
|
40 |
}
|
|
|
3 |
import (
|
4 |
"math"
|
5 |
"math/rand"
|
|
|
|
|
6 |
)
|
7 |
|
8 |
+
func (nn *NN) TrainTestSplit() {
|
9 |
// now we split the data into training
|
10 |
// and testing based on user specified
|
11 |
// nn.TestSize.
|
|
|
34 |
XTest := test.Select(nn.Features)
|
35 |
YTest := test.Select(nn.Target)
|
36 |
|
37 |
+
nn.XTrain = &XTrain
|
38 |
+
nn.YTrain = &YTrain
|
39 |
+
nn.XTest = &XTest
|
40 |
+
nn.YTest = &YTest
|
41 |
+
|
42 |
}
|
nn/subset.go
CHANGED
@@ -3,4 +3,4 @@ package nn
|
|
3 |
// subset the data frame into just the
|
4 |
// features and target that the user specify
|
5 |
|
6 |
-
func
|
|
|
3 |
// subset the data frame into just the
|
4 |
// features and target that the user specify
|
5 |
|
6 |
+
func Subset() {}
|
nn/train.go
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
package nn
|
2 |
+
|
3 |
+
func (nn *NN) Train() {
|
4 |
+
nn.InitWnB()
|
5 |
+
nn.TrainTestSplit()
|
6 |
+
|
7 |
+
// iterate n times where n = nn.Epochs
|
8 |
+
// use backprop algorithm on each iteration
|
9 |
+
// to fit the model to the data
|
10 |
+
for i := 0; i < nn.Epochs; i++ {
|
11 |
+
}
|
12 |
+
|
13 |
+
}
|
server.go
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
package main
|
2 |
|
3 |
import (
|
4 |
-
"fmt"
|
5 |
-
|
6 |
"github.com/Jensen-holm/ml-from-scratch/nn"
|
7 |
"github.com/gofiber/fiber/v2"
|
8 |
)
|
@@ -13,7 +11,6 @@ func main() {
|
|
13 |
// eventually we might want to add a key to this endpoint
|
14 |
// that we will be able to validate.
|
15 |
app.Post("/neural-network", func(c *fiber.Ctx) error {
|
16 |
-
|
17 |
nn, err := nn.NewNN(c)
|
18 |
if err != nil {
|
19 |
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
@@ -22,7 +19,6 @@ func main() {
|
|
22 |
}
|
23 |
|
24 |
nn.Train()
|
25 |
-
fmt.Println(nn.Wo)
|
26 |
|
27 |
return c.SendString("No error")
|
28 |
})
|
|
|
1 |
package main
|
2 |
|
3 |
import (
|
|
|
|
|
4 |
"github.com/Jensen-holm/ml-from-scratch/nn"
|
5 |
"github.com/gofiber/fiber/v2"
|
6 |
)
|
|
|
11 |
// eventually we might want to add a key to this endpoint
|
12 |
// that we will be able to validate.
|
13 |
app.Post("/neural-network", func(c *fiber.Ctx) error {
|
|
|
14 |
nn, err := nn.NewNN(c)
|
15 |
if err != nil {
|
16 |
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
|
|
19 |
}
|
20 |
|
21 |
nn.Train()
|
|
|
22 |
|
23 |
return c.SendString("No error")
|
24 |
})
|