Spaces:
Sleeping
Sleeping
Jensen-holm
commited on
Commit
·
29cce3f
1
Parent(s):
2c781f8
working on new python implementation with cleaner code
Browse files- app.py +27 -0
- example/main.py +3 -2
- go.mod +0 -24
- go.sum +0 -109
- nn/activation.go +0 -41
- nn/args.go +0 -17
- nn/backprop.go +0 -99
- nn/backprop.py +0 -0
- nn/main.go +0 -84
- nn/nn.py +32 -0
- nn/split.go +0 -60
- nn/subset.go +0 -6
- nn/train.go +0 -7
- nn/train.py +15 -0
- requirements.txt +3 -0
- server.go +0 -26
app.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from nn.nn import NN
|
5 |
+
from nn import train as train_nn
|
6 |
+
|
7 |
+
app = Flask(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
@app.route("/neural-network", methods=["POST"])
|
11 |
+
def neural_net():
|
12 |
+
args = request.json
|
13 |
+
|
14 |
+
try:
|
15 |
+
net = NN.from_dict(args)
|
16 |
+
df = pd.read_csv(args.pop("data"))
|
17 |
+
except Exception as e:
|
18 |
+
return jsonify({
|
19 |
+
"bad request": f"could not read csv data: {e}",
|
20 |
+
})
|
21 |
+
|
22 |
+
result = train_nn(nn=net)
|
23 |
+
return jsonify(result)
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
app.run(debug=True)
|
example/main.py
CHANGED
@@ -7,10 +7,11 @@ ARGS = {
|
|
7 |
"epochs": 100,
|
8 |
"hidden_size": 12,
|
9 |
"learning_rate": 0.01,
|
|
|
10 |
"activation": "tanh",
|
11 |
"features": ["sepal width", "sepal length", "petal width", "petal length"],
|
12 |
"target": "species",
|
13 |
-
"data": iris_data.decode(
|
14 |
}
|
15 |
|
16 |
r = requests.post(
|
@@ -19,4 +20,4 @@ r = requests.post(
|
|
19 |
)
|
20 |
|
21 |
if __name__ == "__main__":
|
22 |
-
print(r.json())
|
|
|
7 |
"epochs": 100,
|
8 |
"hidden_size": 12,
|
9 |
"learning_rate": 0.01,
|
10 |
+
"test_size": 0.3,
|
11 |
"activation": "tanh",
|
12 |
"features": ["sepal width", "sepal length", "petal width", "petal length"],
|
13 |
"target": "species",
|
14 |
+
"data": iris_data.decode("utf-8"),
|
15 |
}
|
16 |
|
17 |
r = requests.post(
|
|
|
20 |
)
|
21 |
|
22 |
if __name__ == "__main__":
|
23 |
+
print(r.json())
|
go.mod
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
module github.com/Jensen-holm/ml-from-scratch
|
2 |
-
|
3 |
-
go 1.19
|
4 |
-
|
5 |
-
require (
|
6 |
-
github.com/go-gota/gota v0.12.0
|
7 |
-
github.com/gofiber/fiber/v2 v2.49.2
|
8 |
-
)
|
9 |
-
|
10 |
-
require (
|
11 |
-
github.com/andybalholm/brotli v1.0.5 // indirect
|
12 |
-
github.com/google/uuid v1.3.1 // indirect
|
13 |
-
github.com/klauspost/compress v1.16.7 // indirect
|
14 |
-
github.com/mattn/go-colorable v0.1.13 // indirect
|
15 |
-
github.com/mattn/go-isatty v0.0.19 // indirect
|
16 |
-
github.com/mattn/go-runewidth v0.0.15 // indirect
|
17 |
-
github.com/rivo/uniseg v0.2.0 // indirect
|
18 |
-
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
19 |
-
github.com/valyala/fasthttp v1.49.0 // indirect
|
20 |
-
github.com/valyala/tcplisten v1.0.0 // indirect
|
21 |
-
golang.org/x/net v0.17.0 // indirect
|
22 |
-
golang.org/x/sys v0.13.0 // indirect
|
23 |
-
gonum.org/v1/gonum v0.14.0 // indirect
|
24 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
go.sum
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
2 |
-
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
|
3 |
-
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
4 |
-
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
5 |
-
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
|
6 |
-
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
7 |
-
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
8 |
-
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
9 |
-
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
10 |
-
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
11 |
-
github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g=
|
12 |
-
github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks=
|
13 |
-
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
|
14 |
-
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
|
15 |
-
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
16 |
-
github.com/go-gota/gota v0.12.0 h1:T5BDg1hTf5fZ/CO+T/N0E+DDqUhvoKBl+UVckgcAAQg=
|
17 |
-
github.com/go-gota/gota v0.12.0/go.mod h1:UT+NsWpZC/FhaOyWb9Hui0jXg0Iq8e/YugZHTbyW/34=
|
18 |
-
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
19 |
-
github.com/gofiber/fiber/v2 v2.49.2 h1:ONEN3/Vc+dUCxxDgZZwpqvhISgHqb+bu+isBiEyKEQs=
|
20 |
-
github.com/gofiber/fiber/v2 v2.49.2/go.mod h1:gNsKnyrmfEWFpJxQAV0qvW6l70K1dZGno12oLtukcts=
|
21 |
-
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
22 |
-
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
23 |
-
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
24 |
-
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
25 |
-
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
26 |
-
github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
|
27 |
-
github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
|
28 |
-
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
29 |
-
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
30 |
-
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
31 |
-
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
32 |
-
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
33 |
-
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
34 |
-
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
35 |
-
github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY=
|
36 |
-
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
37 |
-
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
38 |
-
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
39 |
-
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
40 |
-
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
41 |
-
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
42 |
-
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
|
43 |
-
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
44 |
-
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
45 |
-
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
46 |
-
github.com/valyala/fasthttp v1.49.0 h1:9FdvCpmxB74LH4dPb7IJ1cOSsluR07XG3I1txXWwJpE=
|
47 |
-
github.com/valyala/fasthttp v1.49.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
|
48 |
-
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
|
49 |
-
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
|
50 |
-
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
51 |
-
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
52 |
-
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
53 |
-
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
54 |
-
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
55 |
-
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
56 |
-
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
57 |
-
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
58 |
-
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
59 |
-
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
60 |
-
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
61 |
-
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
62 |
-
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
63 |
-
golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
64 |
-
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
65 |
-
golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
66 |
-
golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
67 |
-
golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
68 |
-
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
69 |
-
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
70 |
-
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
71 |
-
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
72 |
-
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
73 |
-
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
74 |
-
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
75 |
-
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
76 |
-
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
77 |
-
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
78 |
-
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
79 |
-
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
80 |
-
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
81 |
-
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
82 |
-
golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
83 |
-
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
84 |
-
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
85 |
-
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
86 |
-
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
87 |
-
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
88 |
-
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
89 |
-
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
90 |
-
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
91 |
-
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
92 |
-
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
93 |
-
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
94 |
-
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
95 |
-
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
96 |
-
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
97 |
-
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
98 |
-
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
99 |
-
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
|
100 |
-
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
101 |
-
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
102 |
-
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
103 |
-
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
|
104 |
-
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
|
105 |
-
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
106 |
-
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
107 |
-
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
108 |
-
gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
|
109 |
-
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/activation.go
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
import "math"
|
4 |
-
|
5 |
-
var ActivationMap = map[string]func(float64) float64{
|
6 |
-
"sigmoid": Sigmoid,
|
7 |
-
"tanh": Tanh,
|
8 |
-
"relu": Relu,
|
9 |
-
}
|
10 |
-
|
11 |
-
func Sigmoid(x float64) float64 {
|
12 |
-
return 1.0 / (1.0 + math.Exp(-x))
|
13 |
-
}
|
14 |
-
|
15 |
-
func SigmoidPrime(x float64) float64 {
|
16 |
-
s := Sigmoid(x)
|
17 |
-
return s / (1.0 - s)
|
18 |
-
}
|
19 |
-
|
20 |
-
func Tanh(x float64) float64 {
|
21 |
-
return math.Tanh(x)
|
22 |
-
}
|
23 |
-
|
24 |
-
func TanhPrime(x float64) float64 {
|
25 |
-
return math.Pow((1.0 / math.Cosh(x)), 2)
|
26 |
-
}
|
27 |
-
|
28 |
-
func Relu(x float64) float64 {
|
29 |
-
if x > 0 {
|
30 |
-
return x
|
31 |
-
}
|
32 |
-
return 0
|
33 |
-
}
|
34 |
-
|
35 |
-
func ReluPrime(x float64) float64 {
|
36 |
-
// maybe want to look into edge case if x == 0
|
37 |
-
if x > 0 {
|
38 |
-
return 1
|
39 |
-
}
|
40 |
-
return 0
|
41 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/args.go
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
type NNArgs struct {
|
4 |
-
epochs int
|
5 |
-
hiddenSize int
|
6 |
-
learningRate float64
|
7 |
-
activationFunc func()
|
8 |
-
}
|
9 |
-
|
10 |
-
func NewArgs(argsMap map[string]interface{}) *NNArgs {
|
11 |
-
return &NNArgs{
|
12 |
-
epochs: argsMap["epochs"].(int),
|
13 |
-
hiddenSize: argsMap["hidden_size"].(int),
|
14 |
-
learningRate: argsMap["learning_rate"].(float64),
|
15 |
-
activationFunc: argsMap["activation"].(func()),
|
16 |
-
}
|
17 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/backprop.go
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
import (
|
4 |
-
"fmt"
|
5 |
-
|
6 |
-
"gonum.org/v1/gonum/mat"
|
7 |
-
)
|
8 |
-
|
9 |
-
func (nn *NN) Backprop() {
|
10 |
-
var (
|
11 |
-
activation = *nn.ActivationFunc
|
12 |
-
// lossHist []float64
|
13 |
-
)
|
14 |
-
|
15 |
-
for i := 0; i < nn.Epochs; i++ {
|
16 |
-
// compute output with current w + b
|
17 |
-
// then compute loss & backprop
|
18 |
-
hiddenOutput, err := computeOutput(
|
19 |
-
nn.XTrain,
|
20 |
-
nn.Wh,
|
21 |
-
nn.Bh,
|
22 |
-
activation,
|
23 |
-
)
|
24 |
-
if err != nil {
|
25 |
-
fmt.Printf("error computing hidden output: %v", err)
|
26 |
-
}
|
27 |
-
|
28 |
-
yHat, err := computeOutput(
|
29 |
-
hiddenOutput,
|
30 |
-
nn.Wo,
|
31 |
-
nn.Bo,
|
32 |
-
activation,
|
33 |
-
)
|
34 |
-
if err != nil {
|
35 |
-
fmt.Printf("error computing yHat: %v", err)
|
36 |
-
}
|
37 |
-
|
38 |
-
mse := meanSquaredError(nn.YTrain, yHat)
|
39 |
-
fmt.Println(mse)
|
40 |
-
|
41 |
-
}
|
42 |
-
|
43 |
-
}
|
44 |
-
|
45 |
-
func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
|
46 |
-
// Check if any of the input matrices is nil
|
47 |
-
if arr == nil || w == nil || b == nil {
|
48 |
-
return nil, fmt.Errorf("Input matrices cannot be nil")
|
49 |
-
}
|
50 |
-
|
51 |
-
// Check input dimensions
|
52 |
-
arrRows, arrCols := arr.Dims()
|
53 |
-
wRows, wCols := w.Dims()
|
54 |
-
bRows, bCols := b.Dims()
|
55 |
-
|
56 |
-
if arrCols != wRows || bCols != wCols {
|
57 |
-
return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
|
58 |
-
}
|
59 |
-
|
60 |
-
// Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
|
61 |
-
var product mat.Dense
|
62 |
-
product.Mul(arr, w)
|
63 |
-
|
64 |
-
// Check dimensions of product and bias
|
65 |
-
productRows, productCols := product.Dims()
|
66 |
-
if productCols != bCols {
|
67 |
-
return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols)
|
68 |
-
}
|
69 |
-
|
70 |
-
// Add the bias matrix 'b' to the product
|
71 |
-
var result mat.Dense
|
72 |
-
result.Add(&product, b)
|
73 |
-
|
74 |
-
// Apply the activation function to the result
|
75 |
-
applyActivation(&result, activationFunc)
|
76 |
-
|
77 |
-
return &result, nil
|
78 |
-
}
|
79 |
-
|
80 |
-
func applyActivation(m *mat.Dense, f func(float64) float64) {
|
81 |
-
r, c := m.Dims()
|
82 |
-
data := m.RawMatrix().Data
|
83 |
-
for i := 0; i < r*c; i++ {
|
84 |
-
data[i] = f(data[i])
|
85 |
-
}
|
86 |
-
}
|
87 |
-
|
88 |
-
func meanSquaredError(y, yHat *mat.Dense) float64 {
|
89 |
-
var sum float64
|
90 |
-
r, c := y.Dims()
|
91 |
-
|
92 |
-
for row := 0; row < r; row++ {
|
93 |
-
for col := 0; col < c; col++ {
|
94 |
-
diff := y.At(row, col) - yHat.At(row, col)
|
95 |
-
sum += (diff * diff)
|
96 |
-
}
|
97 |
-
}
|
98 |
-
return sum / float64((r * c))
|
99 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/backprop.py
ADDED
File without changes
|
nn/main.go
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
import (
|
4 |
-
"fmt"
|
5 |
-
"math/rand"
|
6 |
-
"strings"
|
7 |
-
|
8 |
-
"github.com/go-gota/gota/dataframe"
|
9 |
-
"github.com/gofiber/fiber/v2"
|
10 |
-
"gonum.org/v1/gonum/mat"
|
11 |
-
)
|
12 |
-
|
13 |
-
type NN struct {
|
14 |
-
// attributes set by request
|
15 |
-
CSVData string `json:"csv_data"`
|
16 |
-
Features []string `json:"features"`
|
17 |
-
Target string `json:"target"`
|
18 |
-
Epochs int `json:"epochs"`
|
19 |
-
HiddenSize int `json:"hidden_size"`
|
20 |
-
LearningRate float64 `json:"learning_rate"`
|
21 |
-
Activation string `json:"activation"`
|
22 |
-
TestSize float64 `json:"test_size"`
|
23 |
-
|
24 |
-
// attributes set after args above are parsed
|
25 |
-
ActivationFunc *func(float64) float64
|
26 |
-
Df *dataframe.DataFrame
|
27 |
-
XTrain *mat.Dense
|
28 |
-
YTrain *mat.Dense
|
29 |
-
XTest *mat.Dense
|
30 |
-
YTest *mat.Dense
|
31 |
-
Wh *mat.Dense
|
32 |
-
Bh *mat.Dense
|
33 |
-
Wo *mat.Dense
|
34 |
-
Bo *mat.Dense
|
35 |
-
}
|
36 |
-
|
37 |
-
func NewNN(c *fiber.Ctx) (*NN, error) {
|
38 |
-
newNN := new(NN)
|
39 |
-
err := c.BodyParser(newNN)
|
40 |
-
if err != nil {
|
41 |
-
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
42 |
-
}
|
43 |
-
df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
|
44 |
-
activation := ActivationMap[newNN.Activation]
|
45 |
-
|
46 |
-
newNN.Df = &df
|
47 |
-
newNN.ActivationFunc = &activation
|
48 |
-
return newNN, nil
|
49 |
-
}
|
50 |
-
|
51 |
-
func (nn *NN) InitWnB() {
|
52 |
-
// randomly initialize weights and biases to start
|
53 |
-
inputSize := len(nn.Features)
|
54 |
-
hiddenSize := nn.HiddenSize
|
55 |
-
outputSize := 1 // only predicting one thing
|
56 |
-
|
57 |
-
// Initialize input hidden layer weights as a Gonum matrix
|
58 |
-
wh := mat.NewDense(inputSize, hiddenSize, nil)
|
59 |
-
wh.Apply(func(i, j int, v float64) float64 {
|
60 |
-
return rand.Float64() - 0.5
|
61 |
-
}, wh)
|
62 |
-
|
63 |
-
// Initialize hidden layer bias as a Gonum matrix
|
64 |
-
bh := mat.NewDense(1, hiddenSize, nil)
|
65 |
-
bh.Apply(func(i, j int, v float64) float64 {
|
66 |
-
return rand.Float64() - 0.5
|
67 |
-
}, bh)
|
68 |
-
|
69 |
-
// Initialize weights and biases for hidden -> output layer as Gonum matrices
|
70 |
-
wo := mat.NewDense(hiddenSize, outputSize, nil)
|
71 |
-
wo.Apply(func(i, j int, v float64) float64 {
|
72 |
-
return rand.Float64() - 0.5
|
73 |
-
}, wo)
|
74 |
-
|
75 |
-
bo := mat.NewDense(1, outputSize, nil)
|
76 |
-
bo.Apply(func(i, j int, v float64) float64 {
|
77 |
-
return rand.Float64() - 0.5
|
78 |
-
}, bo)
|
79 |
-
|
80 |
-
nn.Wh = wh
|
81 |
-
nn.Bh = bh
|
82 |
-
nn.Wo = wo
|
83 |
-
nn.Bo = bo
|
84 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/nn.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
|
4 |
+
class NN:
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
epochs: int,
|
8 |
+
hidden_size: int,
|
9 |
+
learning_rate: float,
|
10 |
+
test_size: float,
|
11 |
+
activation: str,
|
12 |
+
features: list[str],
|
13 |
+
target: str,
|
14 |
+
data: str,
|
15 |
+
):
|
16 |
+
self.epochs = epochs
|
17 |
+
self.hidden_size = hidden_size
|
18 |
+
self.learning_rate = learning_rate
|
19 |
+
self.test_size = test_size
|
20 |
+
self.activation = activation
|
21 |
+
self.features = features
|
22 |
+
self.target = target
|
23 |
+
self.data = data
|
24 |
+
|
25 |
+
self.df: pd.DataFrame = None
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
def from_dict(cls, dct):
|
29 |
+
""" Creates an instance of NN given a dictionary
|
30 |
+
we can use this to make sure that the arguments are right
|
31 |
+
"""
|
32 |
+
return cls(**dct)
|
nn/split.go
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
import (
|
4 |
-
"math"
|
5 |
-
"math/rand"
|
6 |
-
|
7 |
-
"github.com/go-gota/gota/dataframe"
|
8 |
-
"gonum.org/v1/gonum/mat"
|
9 |
-
)
|
10 |
-
|
11 |
-
func (nn *NN) TrainTestSplit() {
|
12 |
-
// now we split the data into training
|
13 |
-
// and testing based on user specified
|
14 |
-
// nn.TestSize.
|
15 |
-
nRows := nn.Df.Nrow()
|
16 |
-
testRows := int(math.Floor(float64(nRows) * nn.TestSize))
|
17 |
-
|
18 |
-
// subset the testing data
|
19 |
-
// randomly select trainRows number of rows
|
20 |
-
randStrt := rand.Intn(int(math.Floor(float64(nRows) * nn.TestSize)))
|
21 |
-
test := nn.Df.Subset([]int{randStrt, randStrt + testRows})
|
22 |
-
|
23 |
-
// use what is left for training
|
24 |
-
allIndices := make([]int, nRows)
|
25 |
-
for i := range allIndices {
|
26 |
-
allIndices[i] = i
|
27 |
-
}
|
28 |
-
|
29 |
-
// Remove the test indices using slice append and variadic parameter
|
30 |
-
trainIndices := append(allIndices[:randStrt], allIndices[randStrt+testRows:]...)
|
31 |
-
|
32 |
-
// Create the train DataFrame using the trainIndices
|
33 |
-
train := nn.Df.Subset(trainIndices)
|
34 |
-
|
35 |
-
XTrain := train.Select(nn.Features)
|
36 |
-
YTrain := train.Select(nn.Target)
|
37 |
-
XTest := test.Select(nn.Features)
|
38 |
-
YTest := test.Select(nn.Target)
|
39 |
-
|
40 |
-
// to make linear algebra easier & faster,
|
41 |
-
// we convert these dataframes that we are
|
42 |
-
// performing potentially expensive computations
|
43 |
-
// on into gonum matrices since we no longer need the
|
44 |
-
// column names.
|
45 |
-
nn.XTrain = df2mat(&XTrain)
|
46 |
-
nn.YTrain = df2mat(&YTrain)
|
47 |
-
nn.XTest = df2mat(&XTest)
|
48 |
-
nn.YTest = df2mat(&YTest)
|
49 |
-
}
|
50 |
-
|
51 |
-
// df2mat -> converts gota dataframe into gonum matrix
|
52 |
-
func df2mat(df *dataframe.DataFrame) *mat.Dense {
|
53 |
-
m := mat.NewDense(df.Nrow(), df.Ncol(), nil)
|
54 |
-
for i := 0; i < df.Nrow(); i++ {
|
55 |
-
for j := 0; j < df.Ncol(); j++ {
|
56 |
-
m.Set(i, j, df.Elem(i, j).Float())
|
57 |
-
}
|
58 |
-
}
|
59 |
-
return m
|
60 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/subset.go
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
// subset the data frame into just the
|
4 |
-
// features and target that the user specify
|
5 |
-
|
6 |
-
func Subset() {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/train.go
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
package nn
|
2 |
-
|
3 |
-
func (nn *NN) Train() {
|
4 |
-
nn.InitWnB()
|
5 |
-
nn.TrainTestSplit()
|
6 |
-
nn.Backprop()
|
7 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn/train.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.model_selection import train_test_split
|
2 |
+
from nn.nn import NN
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def train(nn: NN):
|
8 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
9 |
+
nn.X,
|
10 |
+
nn.y,
|
11 |
+
test_size=nn.test_size,
|
12 |
+
random_state=88,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Flask==3.0.0
|
2 |
+
pandas==2.1.1
|
3 |
+
Requests==2.31.0
|
server.go
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
package main
|
2 |
-
|
3 |
-
import (
|
4 |
-
"github.com/Jensen-holm/ml-from-scratch/nn"
|
5 |
-
"github.com/gofiber/fiber/v2"
|
6 |
-
)
|
7 |
-
|
8 |
-
func main() {
|
9 |
-
app := fiber.New()
|
10 |
-
|
11 |
-
// eventually we might want to add a key to this endpoint
|
12 |
-
// that we will be able to validate.
|
13 |
-
app.Post("/neural-network", func(c *fiber.Ctx) error {
|
14 |
-
nn, err := nn.NewNN(c)
|
15 |
-
if err != nil {
|
16 |
-
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
17 |
-
"error": err,
|
18 |
-
})
|
19 |
-
}
|
20 |
-
|
21 |
-
nn.Train()
|
22 |
-
return c.SendString("No error")
|
23 |
-
})
|
24 |
-
|
25 |
-
app.Listen(":3000")
|
26 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|