Spaces:
Build error
Build error
feat: update
Browse files- .gitignore +6 -1
- .vscode/launch.json +16 -0
- .vscode/settings.json +10 -0
- Makefile +3 -1
- assets/example-wandb.jpeg +0 -0
- main.py +9 -4
- model.pth +0 -3
- notebooks/notes.ipynb +0 -0
- requirements.txt +223 -39
- src/configs/model_config.py +2 -1
- src/data/data_loader.py +9 -2
- src/models/model.py +25 -8
- src/train.py +26 -37
- src/utils/__init__.py +17 -0
- src/utils/logs.py +2 -0
- src/utils/test.py +18 -0
- src/utils/train.py +23 -0
- src/utils/utils.py +1 -0
.gitignore
CHANGED
@@ -165,4 +165,9 @@ data/raw/*
|
|
165 |
data/processed/*
|
166 |
|
167 |
!data/raw/.gitkeep
|
168 |
-
!data/processed/.gitkeep
|
|
|
|
|
|
|
|
|
|
|
|
165 |
data/processed/*
|
166 |
|
167 |
!data/raw/.gitkeep
|
168 |
+
!data/processed/.gitkeep
|
169 |
+
|
170 |
+
uvenv/
|
171 |
+
runs/
|
172 |
+
wandb/
|
173 |
+
ubuntu-venv/
|
.vscode/launch.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// Use IntelliSense to learn about possible attributes.
|
3 |
+
// Hover to view descriptions of existing attributes.
|
4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
5 |
+
"version": "0.2.0",
|
6 |
+
"configurations": [
|
7 |
+
{
|
8 |
+
"name": "Python: Current File",
|
9 |
+
"type": "python",
|
10 |
+
"request": "launch",
|
11 |
+
"program": "${file}",
|
12 |
+
"console": "integratedTerminal",
|
13 |
+
"justMyCode": false
|
14 |
+
}
|
15 |
+
]
|
16 |
+
}
|
.vscode/settings.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[python]": {
|
3 |
+
"editor.defaultFormatter": "ms-python.autopep8",
|
4 |
+
|
5 |
+
},
|
6 |
+
"jupyter.debugJustMyCode": false,
|
7 |
+
"debug.allowBreakpointsEverywhere": true,
|
8 |
+
|
9 |
+
"python.formatting.provider": "none"
|
10 |
+
}
|
Makefile
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
package:
|
2 |
-
pip freeze > requirements.txt
|
|
|
|
|
|
1 |
package:
|
2 |
+
pip freeze > requirements.txt
|
3 |
+
venv:
|
4 |
+
source /mnt/d/ubuntu/env/mlenv/bin/activate
|
assets/example-wandb.jpeg
ADDED
![]() |
main.py
CHANGED
@@ -1,5 +1,10 @@
|
|
1 |
-
from src.train import
|
2 |
-
|
|
|
|
|
|
|
|
|
3 |
if __name__ == "__main__":
|
4 |
-
|
5 |
-
|
|
|
|
1 |
+
from src.train import train_runner
|
2 |
+
from src.auto import auto_hyper_parameter
|
3 |
+
import os
|
4 |
+
# set WANDB_API_KEY=$YOUR_API_KEY
|
5 |
+
# os.environ["WANDB_API_KEY"] = "<YOUR_API_KEY>"  # placeholder only: never commit a real key; revoke any key that was committed
|
6 |
+
# os.environ['WANDB_MODE'] = "offline"
|
7 |
# Script entry point: the body below runs only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    print("Training the model...")
    # Plain training run, currently disabled in favor of the call below.
    # train_runner()
    # NOTE(review): auto_hyper_parameter is imported from src.auto, which is
    # not visible here; presumably it launches a hyper-parameter search.
    # Confirm against that module.
    auto_hyper_parameter()
|
model.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e72cf086b03927eeb2527f1e94fe3fbcdda64d8749746a308108f73bb47d9455
|
3 |
-
size 33560199
|
|
|
|
|
|
|
|
notebooks/notes.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
CHANGED
@@ -1,68 +1,252 @@
|
|
|
|
1 |
aiofiles==23.2.1
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
5 |
attrs==23.1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
certifi==2022.12.7
|
|
|
7 |
charset-normalizer==2.1.1
|
8 |
click==8.1.7
|
9 |
colorama==0.4.6
|
|
|
|
|
10 |
contourpy==1.1.1
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
exceptiongroup==1.1.3
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
filelock==3.9.0
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
h11==0.14.0
|
|
|
|
|
|
|
21 |
httpcore==0.18.0
|
22 |
httpx==0.25.0
|
23 |
-
|
|
|
|
|
|
|
24 |
idna==3.4
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
Jinja2==3.1.2
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
kiwisolver==1.4.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
MarkupSafe==2.1.2
|
31 |
matplotlib==3.8.0
|
|
|
|
|
|
|
32 |
mpmath==1.3.0
|
|
|
|
|
|
|
33 |
networkx==3.0
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
pandas==2.1.1
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
pyparsing==3.1.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
python-dateutil==2.8.2
|
44 |
-
python-
|
45 |
-
pytz==
|
|
|
46 |
PyYAML==6.0.1
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
six==1.16.0
|
|
|
52 |
sniffio==1.3.0
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
sympy==1.12
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
tqdm==4.66.1
|
|
|
60 |
typing_extensions==4.8.0
|
61 |
tzdata==2023.3
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.0.0
|
2 |
aiofiles==23.2.1
|
3 |
+
anyio==4.0.0
|
4 |
+
anylabeling==0.3.3
|
5 |
+
appdirs==1.4.4
|
6 |
+
argcomplete==3.1.2
|
7 |
+
asgiref==3.7.2
|
8 |
+
asttokens==2.4.1
|
9 |
+
attr==0.3.1
|
10 |
attrs==23.1.0
|
11 |
+
azure-core==1.29.4
|
12 |
+
azure-storage-blob==12.18.3
|
13 |
+
beautifulsoup4==4.12.2
|
14 |
+
bleach==5.0.1
|
15 |
+
boto==2.49.0
|
16 |
+
boto3==1.16.28
|
17 |
+
botocore==1.19.28
|
18 |
+
boxing==0.1.4
|
19 |
+
Brotli==1.1.0
|
20 |
+
cachetools==5.3.1
|
21 |
certifi==2022.12.7
|
22 |
+
cffi==1.16.0
|
23 |
charset-normalizer==2.1.1
|
24 |
click==8.1.7
|
25 |
colorama==0.4.6
|
26 |
+
coloredlogs==15.0.1
|
27 |
+
comm==0.2.0
|
28 |
contourpy==1.1.1
|
29 |
+
cryptography==41.0.4
|
30 |
+
cycler==0.12.1
|
31 |
+
dacite==1.7.0
|
32 |
+
darkdetect==0.8.0
|
33 |
+
debugpy==1.8.0
|
34 |
+
decorator==5.1.1
|
35 |
+
defusedxml==0.7.1
|
36 |
+
Deprecated==1.2.14
|
37 |
+
dill==0.3.7
|
38 |
+
Django==3.2.20
|
39 |
+
django-annoying==0.10.6
|
40 |
+
django-cors-headers==3.6.0
|
41 |
+
django-debug-toolbar==3.2.1
|
42 |
+
django-environ==0.10.0
|
43 |
+
django-extensions==3.1.0
|
44 |
+
django-filter==2.4.0
|
45 |
+
django-model-utils==4.1.1
|
46 |
+
django-ranged-fileresponse==0.1.2
|
47 |
+
django-rq==2.5.1
|
48 |
+
django-storages==1.12.3
|
49 |
+
django-user-agents==0.4.0
|
50 |
+
djangorestframework==3.13.1
|
51 |
+
dnspython==2.4.2
|
52 |
+
docker-pycreds==0.4.0
|
53 |
+
drf-dynamic-fields==0.3.0
|
54 |
+
drf-flex-fields==0.9.5
|
55 |
+
drf-generators==0.3.0
|
56 |
exceptiongroup==1.1.3
|
57 |
+
executing==2.0.1
|
58 |
+
expiringdict==1.2.2
|
59 |
+
fiftyone==0.22.1
|
60 |
+
fiftyone-brain==0.13.2
|
61 |
+
fiftyone-db==0.4.0
|
62 |
+
fiftyone-desktop==0.29.0
|
63 |
filelock==3.9.0
|
64 |
+
flatbuffers==23.5.26
|
65 |
+
fonttools==4.43.1
|
66 |
+
fsspec==2023.4.0
|
67 |
+
ftfy==6.1.1
|
68 |
+
future==0.18.3
|
69 |
+
gitdb==4.0.11
|
70 |
+
GitPython==3.1.40
|
71 |
+
glob2==0.7
|
72 |
+
google-api-core==2.11.0
|
73 |
+
google-auth==2.23.4
|
74 |
+
google-auth-oauthlib==1.1.0
|
75 |
+
google-cloud-appengine-logging==1.1.0
|
76 |
+
google-cloud-audit-log==0.2.0
|
77 |
+
google-cloud-core==2.3.2
|
78 |
+
google-cloud-logging==2.7.2
|
79 |
+
google-cloud-storage==2.5.0
|
80 |
+
google-crc32c==1.5.0
|
81 |
+
google-resumable-media==2.3.3
|
82 |
+
googleapis-common-protos==1.56.4
|
83 |
+
graphql-core==3.2.3
|
84 |
+
grpc-google-iam-v1==0.12.4
|
85 |
+
grpcio==1.59.0
|
86 |
+
grpcio-status==1.48.2
|
87 |
h11==0.14.0
|
88 |
+
h2==4.1.0
|
89 |
+
hpack==4.0.0
|
90 |
+
htmlmin==0.1.12
|
91 |
httpcore==0.18.0
|
92 |
httpx==0.25.0
|
93 |
+
humanfriendly==10.0
|
94 |
+
humansignal-drf-yasg==1.21.9
|
95 |
+
hypercorn==0.14.4
|
96 |
+
hyperframe==6.0.1
|
97 |
idna==3.4
|
98 |
+
ijson==3.2.3
|
99 |
+
imageio==2.31.5
|
100 |
+
imgviz==1.7.4
|
101 |
+
inflate64==0.3.1
|
102 |
+
inflection==0.5.1
|
103 |
+
ipykernel==6.26.0
|
104 |
+
ipython==8.17.2
|
105 |
+
isodate==0.6.1
|
106 |
+
jedi==0.19.1
|
107 |
Jinja2==3.1.2
|
108 |
+
jmespath==0.10.0
|
109 |
+
joblib==1.3.2
|
110 |
+
jsonlines==4.0.0
|
111 |
+
jsonpatch==1.33
|
112 |
+
jsonpointer==2.4
|
113 |
+
jsonschema==3.2.0
|
114 |
+
jupyter_client==8.6.0
|
115 |
+
jupyter_core==5.5.0
|
116 |
+
kaleido==0.2.1
|
117 |
kiwisolver==1.4.5
|
118 |
+
label-studio==1.9.1.post0
|
119 |
+
label-studio-converter==0.0.57
|
120 |
+
label-studio-tools==0.0.3
|
121 |
+
launchdarkly-server-sdk==7.5.0
|
122 |
+
lazy_loader==0.3
|
123 |
+
lockfile==0.12.2
|
124 |
+
lxml==4.9.3
|
125 |
+
Markdown==3.5.1
|
126 |
MarkupSafe==2.1.2
|
127 |
matplotlib==3.8.0
|
128 |
+
matplotlib-inline==0.1.6
|
129 |
+
mongoengine==0.24.2
|
130 |
+
motor==3.3.1
|
131 |
mpmath==1.3.0
|
132 |
+
multivolumefile==0.2.3
|
133 |
+
natsort==8.4.0
|
134 |
+
nest-asyncio==1.5.8
|
135 |
networkx==3.0
|
136 |
+
nltk==3.6.7
|
137 |
+
numpy==1.24.3
|
138 |
+
oauthlib==3.2.2
|
139 |
+
onnx==1.13.1
|
140 |
+
onnxruntime==1.14.1
|
141 |
+
opencv-python==4.8.1.78
|
142 |
+
opencv-python-headless==4.8.1.78
|
143 |
+
ordered-set==4.0.2
|
144 |
+
packaging==23.2
|
145 |
pandas==2.1.1
|
146 |
+
parso==0.8.3
|
147 |
+
pathtools==0.1.2
|
148 |
+
Pillow==10.0.1
|
149 |
+
platformdirs==4.0.0
|
150 |
+
plotly==5.17.0
|
151 |
+
pprintpp==0.4.0
|
152 |
+
priority==2.0.0
|
153 |
+
prompt-toolkit==3.0.40
|
154 |
+
proto-plus==1.22.3
|
155 |
+
protobuf==3.20.3
|
156 |
+
psutil==5.9.5
|
157 |
+
psycopg2-binary==2.9.6
|
158 |
+
pure-eval==0.2.2
|
159 |
+
py-cpuinfo==9.0.0
|
160 |
+
py7zr==0.20.6
|
161 |
+
pyasn1==0.5.0
|
162 |
+
pyasn1-modules==0.3.0
|
163 |
+
pybcj==1.0.1
|
164 |
+
pycparser==2.21
|
165 |
+
pycryptodomex==3.19.0
|
166 |
+
pydantic==1.10.13
|
167 |
+
Pygments==2.16.1
|
168 |
+
pymongo==4.5.0
|
169 |
pyparsing==3.1.1
|
170 |
+
pyppmd==1.0.0
|
171 |
+
PyQt5==5.15.10
|
172 |
+
PyQt5-Qt5==5.15.2
|
173 |
+
PyQt5-sip==12.13.0
|
174 |
+
pyreadline3==3.4.1
|
175 |
+
pyRFC3339==1.1
|
176 |
+
pyrsistent==0.19.3
|
177 |
python-dateutil==2.8.2
|
178 |
+
python-json-logger==2.0.4
|
179 |
+
pytz==2022.7.1
|
180 |
+
pywin32==306
|
181 |
PyYAML==6.0.1
|
182 |
+
pyzmq==25.1.1
|
183 |
+
pyzstd==0.15.9
|
184 |
+
qimage2ndarray==1.10.0
|
185 |
+
rarfile==4.1
|
186 |
+
redis==3.5.3
|
187 |
+
regex==2023.10.3
|
188 |
+
requests==2.31.0
|
189 |
+
requests-oauthlib==1.3.1
|
190 |
+
retrying==1.3.4
|
191 |
+
rq==1.10.1
|
192 |
+
rsa==4.9
|
193 |
+
rules==2.2
|
194 |
+
s3transfer==0.3.7
|
195 |
+
scikit-image==0.22.0
|
196 |
+
scikit-learn==1.3.1
|
197 |
+
scipy==1.11.3
|
198 |
+
seaborn==0.13.0
|
199 |
+
semver==2.13.0
|
200 |
+
sentry-sdk==1.32.0
|
201 |
+
setproctitle==1.3.3
|
202 |
+
simplejson==3.19.2
|
203 |
six==1.16.0
|
204 |
+
smmap==5.0.1
|
205 |
sniffio==1.3.0
|
206 |
+
sortedcontainers==2.4.0
|
207 |
+
soupsieve==2.5
|
208 |
+
sqlparse==0.4.4
|
209 |
+
sse-starlette==0.10.3
|
210 |
+
sseclient-py==1.8.0
|
211 |
+
stack-data==0.6.3
|
212 |
+
starlette==0.31.1
|
213 |
+
strawberry-graphql==0.138.1
|
214 |
sympy==1.12
|
215 |
+
tabulate==0.9.0
|
216 |
+
tenacity==8.2.3
|
217 |
+
tensorboard==2.15.1
|
218 |
+
tensorboard-data-server==0.7.2
|
219 |
+
termcolor==2.3.0
|
220 |
+
texttable==1.7.0
|
221 |
+
thop==0.1.1.post2209072238
|
222 |
+
threadpoolctl==3.2.0
|
223 |
+
tifffile==2023.9.26
|
224 |
+
tomli==2.0.1
|
225 |
+
torch==2.1.0+cu118
|
226 |
+
torchaudio==2.1.0+cu118
|
227 |
+
torchsummary==1.5.1
|
228 |
+
torchvision==0.16.0+cu118
|
229 |
+
tornado==6.3.3
|
230 |
tqdm==4.66.1
|
231 |
+
traitlets==5.13.0
|
232 |
typing_extensions==4.8.0
|
233 |
tzdata==2023.3
|
234 |
+
tzlocal==5.1
|
235 |
+
ua-parser==0.18.0
|
236 |
+
ujson==5.8.0
|
237 |
+
ultralytics==8.0.197
|
238 |
+
universal-analytics-python3==1.1.1
|
239 |
+
uritemplate==4.1.1
|
240 |
+
urllib3==1.26.16
|
241 |
+
user-agents==2.2.0
|
242 |
+
visdom==0.2.4
|
243 |
+
voxel51-eta==0.12.0
|
244 |
+
wandb==0.16.0
|
245 |
+
wcwidth==0.2.8
|
246 |
+
webencodings==0.5.1
|
247 |
+
websocket-client==1.6.4
|
248 |
+
Werkzeug==3.0.1
|
249 |
+
wrapt==1.15.0
|
250 |
+
wsproto==1.2.0
|
251 |
+
xmljson==0.2.0
|
252 |
+
xmltodict==0.13.0
|
src/configs/model_config.py
CHANGED
@@ -5,6 +5,7 @@ class ModelConfig:
|
|
5 |
self.learning_rate = 0.001
|
6 |
self.batch_size = 32
|
7 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
-
self.epochs =
|
|
|
9 |
def get_config(self):
|
10 |
return self
|
|
|
5 |
self.learning_rate = 0.001
|
6 |
self.batch_size = 32
|
7 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
+
self.epochs = 5
|
9 |
+
self.log_interval = 2 # Log every 2 batches => number of items is 32*2 = 64
|
10 |
def get_config(self):
    """Return this configuration object itself (no copy is made)."""
    return self
|
src/data/data_loader.py
CHANGED
@@ -7,7 +7,8 @@ import os
|
|
7 |
num_classes = 3
|
8 |
config = ModelConfig().get_config()
|
9 |
|
10 |
-
train_dataset = CustomDataset(data_folder=os.path.join(
|
|
|
11 |
|
12 |
# # Calculate the split point
|
13 |
# split_index = int(0.8 * len(dataset))
|
@@ -17,5 +18,11 @@ train_dataset = CustomDataset(data_folder=os.path.join("data", 'raw'), transform
|
|
17 |
# test_dataset = dataset[split_index:]
|
18 |
|
19 |
|
20 |
-
train_loader = DataLoader(
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
7 |
num_classes = 3
|
8 |
config = ModelConfig().get_config()
|
9 |
|
10 |
+
train_dataset = CustomDataset(data_folder=os.path.join(
|
11 |
+
"data", 'raw'), transform=data_transform)
|
12 |
|
13 |
# # Calculate the split point
|
14 |
# split_index = int(0.8 * len(dataset))
|
|
|
18 |
# test_dataset = dataset[split_index:]
|
19 |
|
20 |
|
21 |
+
train_loader = DataLoader(
|
22 |
+
train_dataset, batch_size=config.batch_size, shuffle=True)
|
23 |
+
|
24 |
+
|
25 |
+
def get_train_dataset(batch_size):
    """Build a shuffling ``DataLoader`` over the module-level training dataset.

    Despite its name, this returns a loader rather than the dataset itself.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A new ``DataLoader`` wrapping ``train_dataset`` with ``shuffle=True``.
    """
    loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    return loader
|
28 |
|
src/models/model.py
CHANGED
@@ -1,17 +1,34 @@
|
|
1 |
from torch import nn
|
2 |
import torch.nn.functional as F
|
3 |
-
|
4 |
class ShapeClassifier(nn.Module):
|
5 |
-
def __init__(self, num_classes):
|
6 |
super(ShapeClassifier, self).__init__()
|
7 |
-
|
8 |
-
self.
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def forward(self, x):
|
|
|
|
|
|
|
13 |
x = self.pool(F.relu(self.conv1(x)))
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
x = F.relu(self.fc1(x))
|
|
|
|
|
16 |
x = self.fc2(x)
|
17 |
-
|
|
|
|
1 |
from torch import nn
|
2 |
import torch.nn.functional as F
|
3 |
+
# Input images are 128x128 pixels with 3 channels (128x128x3).
|
4 |
class ShapeClassifier(nn.Module):
    """Small CNN classifier: one conv layer, one max-pool, two linear layers.

    The first linear layer expects ``16 * 64 * 64`` features per sample, so
    the network only accepts 3-channel 128x128 images: the convolution keeps
    the spatial size (3x3 kernel, padding 1) and the pooling halves it.
    """

    def __init__(self, num_classes, hidden_size=128):
        """Create the layers.

        Args:
            num_classes: Number of output logits.
            hidden_size: Width of the hidden fully connected layer.
        """
        super().__init__()
        # 3 input channels (RGB) -> 16 feature maps; a 3x3 kernel with
        # padding 1 preserves height and width (128x128x16 for 128x128 input).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

        # 2x2 max pooling with stride 2 halves height and width (64x64x16).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flattened pooled features -> hidden representation.
        self.fc1 = nn.Linear(16 * 64 * 64, hidden_size)

        # Hidden representation -> one logit per class.
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image tensor of shape (N, 3, 128, 128); a single unbatched
                image of shape (3, 128, 128) is treated as a batch of one.

        Returns:
            Tensor of shape (N, num_classes) holding raw logits. No softmax
            is applied here.

        Raises:
            RuntimeError: Raised by the linear layer when the spatial size is
                not 128x128, because the flattened feature count no longer
                matches ``fc1``.
        """
        # Keep supporting a single unbatched image by adding the batch axis.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # Convolution, ReLU activation, then max pooling.
        x = self.pool(F.relu(self.conv1(x)))

        # Flatten per sample. Unlike ``view(-1, 16 * 64 * 64)``, this never
        # moves surplus features into the batch dimension, so a wrongly sized
        # input fails loudly instead of silently changing the batch size.
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x))

        # Output layer without activation; the loss function applies it.
        return self.fc2(x)
|
src/train.py
CHANGED
@@ -1,51 +1,40 @@
|
|
1 |
import torch
|
2 |
import torch.optim as optim
|
3 |
import torch.nn.functional as F
|
4 |
-
from .models.model import ShapeClassifier
|
5 |
|
6 |
from src.configs.model_config import ModelConfig
|
7 |
from src.data.data_loader import train_loader, num_classes
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
-
def
|
|
|
11 |
config = ModelConfig().get_config()
|
12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
model = ShapeClassifier(num_classes=num_classes).to(device)
|
14 |
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
15 |
log_interval = 20
|
|
|
|
|
|
|
16 |
for epoch in range(config.epochs):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
if batch_idx % log_interval == 0:
|
32 |
-
current_loss = running_loss / log_interval
|
33 |
-
print(
|
34 |
-
f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
|
35 |
-
running_loss = 0.0
|
36 |
-
|
37 |
-
# calculate the accuracy on the test set
|
38 |
-
|
39 |
-
with torch.no_grad():
|
40 |
-
model.eval()
|
41 |
-
correct = 0
|
42 |
-
total = 0
|
43 |
-
for inputs, labels in train_loader:
|
44 |
-
inputs, labels = inputs.to(device), labels.to(device)
|
45 |
-
outputs = model(inputs)
|
46 |
-
predicted = torch.argmax(outputs.data, 1)
|
47 |
-
total += labels.size(0)
|
48 |
-
correct += (predicted == labels).sum().item()
|
49 |
-
print(f"Accuracy of the model on the test images: {100 * correct / total} %")
|
50 |
-
# save the model
|
51 |
-
torch.save(model.state_dict(), "model.pth")
|
|
|
1 |
import torch
|
2 |
import torch.optim as optim
|
3 |
import torch.nn.functional as F
|
4 |
+
from src.models.model import ShapeClassifier
|
5 |
|
6 |
from src.configs.model_config import ModelConfig
|
7 |
from src.data.data_loader import train_loader, num_classes
|
8 |
+
from src.utils.logs import writer
|
9 |
+
from src.utils.train import train
|
10 |
+
from src.utils.test import test
|
11 |
+
import wandb
|
12 |
+
import json
|
13 |
+
wandb.init(project="template-pytorch-model", entity="nguyen")
|
14 |
|
15 |
|
16 |
+
def train_runner():
    """Train a ``ShapeClassifier``, log to Weights & Biases, and save weights.

    Reads hyper-parameters from ``ModelConfig``, trains for ``config.epochs``
    epochs on ``train_loader``, logs the per-epoch loss to the active W&B run,
    then writes the weights to ``model.pth`` and uploads that file as a W&B
    artifact once training finishes.
    """
    config = ModelConfig().get_config()
    # Use the device from the config so the model lives on the same device
    # that the train/test helpers move each batch to.
    device = torch.device(config.device)
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    # Record the hyper-parameters as a plain dict of the config's attributes,
    # which is the documented input form for wandb.config.update.
    wandb.config.update(vars(config))

    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")

        loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
                     optimizer=optimizer)
        # NOTE: evaluation reuses the training loader, so the reported
        # accuracy is training accuracy, not a held-out measurement.
        test(train_loader, model=model, loss_fn=F.cross_entropy)

        # float() accepts both a Python float and a 0-d tensor, so the
        # logged value has a stable type whichever one train() returns.
        wandb.log({"loss": float(loss)})

    # Persist the trained weights locally, then attach them to the W&B run.
    torch.save(model.state_dict(), "model.pth")
    wandb.log_artifact("model.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import os
|
3 |
+
from inspect import isclass
|
4 |
+
|
5 |
+
# import all files under utils/
|
6 |
+
# Import every module and sub-package that sits next to this file and
# re-export each class found in them from this package's namespace.
utils_dir = os.path.dirname(__file__)
for file in os.listdir(utils_dir):
    path = os.path.join(utils_dir, file)

    # Skip private/dunder entries and hidden files.
    if file.startswith(("_", ".")):
        continue

    if file.endswith(".py"):
        # Module name is everything before the first ".py".
        config_name = file[: file.find(".py")]
    elif os.path.isdir(path):
        config_name = file
    else:
        # Neither a Python file nor a directory: nothing to import.
        continue

    module = importlib.import_module("src.utils." + config_name)
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if not isclass(attribute):
            continue
        # Expose the class as an attribute of this package.
        globals()[attribute_name] = attribute
|
src/utils/logs.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from torch.utils.tensorboard import SummaryWriter
|
2 |
+
writer = SummaryWriter()
|
src/utils/test.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.configs.model_config import ModelConfig
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    The model is switched to eval mode and left there. Batches are moved to
    ``ModelConfig().device`` before the forward pass, and gradients are
    disabled.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches exposing
            ``.dataset`` and ``len()``.
        model: Module mapping inputs to per-class scores of shape (N, C).
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        tuple[float, float]: ``(accuracy, avg_loss)``. Accuracy is the
        fraction of correct predictions in [0, 1]; ``avg_loss`` is the mean
        of the per-batch losses.

    Raises:
        ValueError: If the dataloader has no batches or its dataset is empty.
            Previously this surfaced as a ZeroDivisionError.
    """
    config = ModelConfig().get_config()
    device = config.device
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if size == 0 or num_batches == 0:
        raise ValueError("test() requires a non-empty dataloader")

    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # The highest-scoring class along dim 1 is the prediction.
            correct += (pred.argmax(1) == y).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct, test_loss
|
src/utils/train.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.configs.model_config import ModelConfig
|
2 |
+
import torch
|
3 |
+
def train(dataloader, model, loss_fn, optimizer):
    """Run one training pass over ``dataloader`` and return the last loss.

    The model is put in train mode. Each batch is moved to
    ``ModelConfig().device``, followed by a forward pass, backpropagation and
    an optimizer step. Progress is printed every ``config.log_interval``
    batches.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches exposing
            ``.dataset``.
        model: Module being trained.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        float: Loss of the final batch. This is always a Python float.
        Previously the type was a float or a tensor depending on whether the
        final batch happened to be a logging batch.

    Raises:
        ValueError: If the dataloader yields no batches. Previously this
            surfaced as an UnboundLocalError.
    """
    config = ModelConfig().get_config()
    device = config.device
    log_interval = config.log_interval
    size = len(dataloader.dataset)
    model.train()

    last_loss = None
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation; gradients are cleared after the step so the next
        # batch starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Keep a detached handle and convert to float only when needed, to
        # avoid a host/device sync on every batch.
        last_loss = loss.detach()

        if batch % log_interval == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {last_loss.item():>7f} [{current:>5d}/{size:>5d}]")

    if last_loss is None:
        raise ValueError("train() requires a non-empty dataloader")
    return last_loss.item()
|
src/utils/utils.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|