npv2k1 commited on
Commit
0e63e05
·
verified ·
1 Parent(s): 7d5fa6b

feat: update

Browse files
.gitignore CHANGED
@@ -165,4 +165,9 @@ data/raw/*
165
  data/processed/*
166
 
167
  !data/raw/.gitkeep
168
- !data/processed/.gitkeep
 
 
 
 
 
 
165
  data/processed/*
166
 
167
  !data/raw/.gitkeep
168
+ !data/processed/.gitkeep
169
+
170
+ uvenv/
171
+ runs/
172
+ wandb/
173
+ ubuntu-venv/
.vscode/launch.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python: Current File",
9
+ "type": "python",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "console": "integratedTerminal",
13
+ "justMyCode": false
14
+ }
15
+ ]
16
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.autopep8",
4
+
5
+ },
6
+ "jupyter.debugJustMyCode": false,
7
+ "debug.allowBreakpointsEverywhere": true,
8
+
9
+ "python.formatting.provider": "none"
10
+ }
Makefile CHANGED
@@ -1,2 +1,4 @@
1
  package:
2
- pip freeze > requirements.txt
 
 
 
1
  package:
2
+ pip freeze > requirements.txt
3
+ venv:
4
+ source /mnt/d/ubuntu/env/mlenv/bin/activate
assets/example-wandb.jpeg ADDED
main.py CHANGED
@@ -1,5 +1,10 @@
1
- from src.train import train
2
-
 
 
 
 
3
  if __name__ == "__main__":
4
- train()
5
-
 
 
1
+ from src.train import train_runner
2
+ from src.auto import auto_hyper_parameter
3
+ import os
4
+ # set WANDB_API_KEY=$YOUR_API_KEY
5
+ # os.environ["WANDB_API_KEY"] = '7c0f2b9470a0a5c82bfae5bab4705344cb53288b'
6
+ # os.environ['WANDB_MODE'] = "offline"
7
  if __name__ == "__main__":
8
+ print("Training the model...")
9
+ # train_runner()
10
+ auto_hyper_parameter()
model.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e72cf086b03927eeb2527f1e94fe3fbcdda64d8749746a308108f73bb47d9455
3
- size 33560199
 
 
 
 
notebooks/notes.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,68 +1,252 @@
 
1
  aiofiles==23.2.1
2
- altair==5.1.1
3
- annotated-types==0.5.0
4
- anyio==3.7.1
 
 
 
 
5
  attrs==23.1.0
 
 
 
 
 
 
 
 
 
 
6
  certifi==2022.12.7
 
7
  charset-normalizer==2.1.1
8
  click==8.1.7
9
  colorama==0.4.6
 
 
10
  contourpy==1.1.1
11
- cycler==0.12.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  exceptiongroup==1.1.3
13
- fastapi==0.103.2
14
- ffmpy==0.3.1
 
 
 
 
15
  filelock==3.9.0
16
- fonttools==4.43.0
17
- fsspec==2023.9.2
18
- gradio==3.45.2
19
- gradio_client==0.5.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  h11==0.14.0
 
 
 
21
  httpcore==0.18.0
22
  httpx==0.25.0
23
- huggingface-hub==0.17.3
 
 
 
24
  idna==3.4
25
- importlib-resources==6.1.0
 
 
 
 
 
 
 
 
26
  Jinja2==3.1.2
27
- jsonschema==4.19.1
28
- jsonschema-specifications==2023.7.1
 
 
 
 
 
 
 
29
  kiwisolver==1.4.5
 
 
 
 
 
 
 
 
30
  MarkupSafe==2.1.2
31
  matplotlib==3.8.0
 
 
 
32
  mpmath==1.3.0
 
 
 
33
  networkx==3.0
34
- numpy==1.24.1
35
- orjson==3.9.7
36
- packaging==23.1
 
 
 
 
 
 
37
  pandas==2.1.1
38
- Pillow==9.3.0
39
- pydantic==2.4.2
40
- pydantic_core==2.10.1
41
- pydub==0.25.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  pyparsing==3.1.1
 
 
 
 
 
 
 
43
  python-dateutil==2.8.2
44
- python-multipart==0.0.6
45
- pytz==2023.3.post1
 
46
  PyYAML==6.0.1
47
- referencing==0.30.2
48
- requests==2.28.1
49
- rpds-py==0.10.3
50
- semantic-version==2.10.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  six==1.16.0
 
52
  sniffio==1.3.0
53
- starlette==0.27.0
 
 
 
 
 
 
 
54
  sympy==1.12
55
- toolz==0.12.0
56
- # torch==2.0.1+cu117
57
- # torchaudio==2.0.2+cu117
58
- # torchvision==0.15.2+cu117
 
 
 
 
 
 
 
 
 
 
 
59
  tqdm==4.66.1
 
60
  typing_extensions==4.8.0
61
  tzdata==2023.3
62
- urllib3==1.26.13
63
- uvicorn==0.23.2
64
- websockets==11.0.3
65
-
66
- torch==2.0.1
67
- torchaudio==2.0.2
68
- torchvision==0.15.2
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
  aiofiles==23.2.1
3
+ anyio==4.0.0
4
+ anylabeling==0.3.3
5
+ appdirs==1.4.4
6
+ argcomplete==3.1.2
7
+ asgiref==3.7.2
8
+ asttokens==2.4.1
9
+ attr==0.3.1
10
  attrs==23.1.0
11
+ azure-core==1.29.4
12
+ azure-storage-blob==12.18.3
13
+ beautifulsoup4==4.12.2
14
+ bleach==5.0.1
15
+ boto==2.49.0
16
+ boto3==1.16.28
17
+ botocore==1.19.28
18
+ boxing==0.1.4
19
+ Brotli==1.1.0
20
+ cachetools==5.3.1
21
  certifi==2022.12.7
22
+ cffi==1.16.0
23
  charset-normalizer==2.1.1
24
  click==8.1.7
25
  colorama==0.4.6
26
+ coloredlogs==15.0.1
27
+ comm==0.2.0
28
  contourpy==1.1.1
29
+ cryptography==41.0.4
30
+ cycler==0.12.1
31
+ dacite==1.7.0
32
+ darkdetect==0.8.0
33
+ debugpy==1.8.0
34
+ decorator==5.1.1
35
+ defusedxml==0.7.1
36
+ Deprecated==1.2.14
37
+ dill==0.3.7
38
+ Django==3.2.20
39
+ django-annoying==0.10.6
40
+ django-cors-headers==3.6.0
41
+ django-debug-toolbar==3.2.1
42
+ django-environ==0.10.0
43
+ django-extensions==3.1.0
44
+ django-filter==2.4.0
45
+ django-model-utils==4.1.1
46
+ django-ranged-fileresponse==0.1.2
47
+ django-rq==2.5.1
48
+ django-storages==1.12.3
49
+ django-user-agents==0.4.0
50
+ djangorestframework==3.13.1
51
+ dnspython==2.4.2
52
+ docker-pycreds==0.4.0
53
+ drf-dynamic-fields==0.3.0
54
+ drf-flex-fields==0.9.5
55
+ drf-generators==0.3.0
56
  exceptiongroup==1.1.3
57
+ executing==2.0.1
58
+ expiringdict==1.2.2
59
+ fiftyone==0.22.1
60
+ fiftyone-brain==0.13.2
61
+ fiftyone-db==0.4.0
62
+ fiftyone-desktop==0.29.0
63
  filelock==3.9.0
64
+ flatbuffers==23.5.26
65
+ fonttools==4.43.1
66
+ fsspec==2023.4.0
67
+ ftfy==6.1.1
68
+ future==0.18.3
69
+ gitdb==4.0.11
70
+ GitPython==3.1.40
71
+ glob2==0.7
72
+ google-api-core==2.11.0
73
+ google-auth==2.23.4
74
+ google-auth-oauthlib==1.1.0
75
+ google-cloud-appengine-logging==1.1.0
76
+ google-cloud-audit-log==0.2.0
77
+ google-cloud-core==2.3.2
78
+ google-cloud-logging==2.7.2
79
+ google-cloud-storage==2.5.0
80
+ google-crc32c==1.5.0
81
+ google-resumable-media==2.3.3
82
+ googleapis-common-protos==1.56.4
83
+ graphql-core==3.2.3
84
+ grpc-google-iam-v1==0.12.4
85
+ grpcio==1.59.0
86
+ grpcio-status==1.48.2
87
  h11==0.14.0
88
+ h2==4.1.0
89
+ hpack==4.0.0
90
+ htmlmin==0.1.12
91
  httpcore==0.18.0
92
  httpx==0.25.0
93
+ humanfriendly==10.0
94
+ humansignal-drf-yasg==1.21.9
95
+ hypercorn==0.14.4
96
+ hyperframe==6.0.1
97
  idna==3.4
98
+ ijson==3.2.3
99
+ imageio==2.31.5
100
+ imgviz==1.7.4
101
+ inflate64==0.3.1
102
+ inflection==0.5.1
103
+ ipykernel==6.26.0
104
+ ipython==8.17.2
105
+ isodate==0.6.1
106
+ jedi==0.19.1
107
  Jinja2==3.1.2
108
+ jmespath==0.10.0
109
+ joblib==1.3.2
110
+ jsonlines==4.0.0
111
+ jsonpatch==1.33
112
+ jsonpointer==2.4
113
+ jsonschema==3.2.0
114
+ jupyter_client==8.6.0
115
+ jupyter_core==5.5.0
116
+ kaleido==0.2.1
117
  kiwisolver==1.4.5
118
+ label-studio==1.9.1.post0
119
+ label-studio-converter==0.0.57
120
+ label-studio-tools==0.0.3
121
+ launchdarkly-server-sdk==7.5.0
122
+ lazy_loader==0.3
123
+ lockfile==0.12.2
124
+ lxml==4.9.3
125
+ Markdown==3.5.1
126
  MarkupSafe==2.1.2
127
  matplotlib==3.8.0
128
+ matplotlib-inline==0.1.6
129
+ mongoengine==0.24.2
130
+ motor==3.3.1
131
  mpmath==1.3.0
132
+ multivolumefile==0.2.3
133
+ natsort==8.4.0
134
+ nest-asyncio==1.5.8
135
  networkx==3.0
136
+ nltk==3.6.7
137
+ numpy==1.24.3
138
+ oauthlib==3.2.2
139
+ onnx==1.13.1
140
+ onnxruntime==1.14.1
141
+ opencv-python==4.8.1.78
142
+ opencv-python-headless==4.8.1.78
143
+ ordered-set==4.0.2
144
+ packaging==23.2
145
  pandas==2.1.1
146
+ parso==0.8.3
147
+ pathtools==0.1.2
148
+ Pillow==10.0.1
149
+ platformdirs==4.0.0
150
+ plotly==5.17.0
151
+ pprintpp==0.4.0
152
+ priority==2.0.0
153
+ prompt-toolkit==3.0.40
154
+ proto-plus==1.22.3
155
+ protobuf==3.20.3
156
+ psutil==5.9.5
157
+ psycopg2-binary==2.9.6
158
+ pure-eval==0.2.2
159
+ py-cpuinfo==9.0.0
160
+ py7zr==0.20.6
161
+ pyasn1==0.5.0
162
+ pyasn1-modules==0.3.0
163
+ pybcj==1.0.1
164
+ pycparser==2.21
165
+ pycryptodomex==3.19.0
166
+ pydantic==1.10.13
167
+ Pygments==2.16.1
168
+ pymongo==4.5.0
169
  pyparsing==3.1.1
170
+ pyppmd==1.0.0
171
+ PyQt5==5.15.10
172
+ PyQt5-Qt5==5.15.2
173
+ PyQt5-sip==12.13.0
174
+ pyreadline3==3.4.1
175
+ pyRFC3339==1.1
176
+ pyrsistent==0.19.3
177
  python-dateutil==2.8.2
178
+ python-json-logger==2.0.4
179
+ pytz==2022.7.1
180
+ pywin32==306
181
  PyYAML==6.0.1
182
+ pyzmq==25.1.1
183
+ pyzstd==0.15.9
184
+ qimage2ndarray==1.10.0
185
+ rarfile==4.1
186
+ redis==3.5.3
187
+ regex==2023.10.3
188
+ requests==2.31.0
189
+ requests-oauthlib==1.3.1
190
+ retrying==1.3.4
191
+ rq==1.10.1
192
+ rsa==4.9
193
+ rules==2.2
194
+ s3transfer==0.3.7
195
+ scikit-image==0.22.0
196
+ scikit-learn==1.3.1
197
+ scipy==1.11.3
198
+ seaborn==0.13.0
199
+ semver==2.13.0
200
+ sentry-sdk==1.32.0
201
+ setproctitle==1.3.3
202
+ simplejson==3.19.2
203
  six==1.16.0
204
+ smmap==5.0.1
205
  sniffio==1.3.0
206
+ sortedcontainers==2.4.0
207
+ soupsieve==2.5
208
+ sqlparse==0.4.4
209
+ sse-starlette==0.10.3
210
+ sseclient-py==1.8.0
211
+ stack-data==0.6.3
212
+ starlette==0.31.1
213
+ strawberry-graphql==0.138.1
214
  sympy==1.12
215
+ tabulate==0.9.0
216
+ tenacity==8.2.3
217
+ tensorboard==2.15.1
218
+ tensorboard-data-server==0.7.2
219
+ termcolor==2.3.0
220
+ texttable==1.7.0
221
+ thop==0.1.1.post2209072238
222
+ threadpoolctl==3.2.0
223
+ tifffile==2023.9.26
224
+ tomli==2.0.1
225
+ torch==2.1.0+cu118
226
+ torchaudio==2.1.0+cu118
227
+ torchsummary==1.5.1
228
+ torchvision==0.16.0+cu118
229
+ tornado==6.3.3
230
  tqdm==4.66.1
231
+ traitlets==5.13.0
232
  typing_extensions==4.8.0
233
  tzdata==2023.3
234
+ tzlocal==5.1
235
+ ua-parser==0.18.0
236
+ ujson==5.8.0
237
+ ultralytics==8.0.197
238
+ universal-analytics-python3==1.1.1
239
+ uritemplate==4.1.1
240
+ urllib3==1.26.16
241
+ user-agents==2.2.0
242
+ visdom==0.2.4
243
+ voxel51-eta==0.12.0
244
+ wandb==0.16.0
245
+ wcwidth==0.2.8
246
+ webencodings==0.5.1
247
+ websocket-client==1.6.4
248
+ Werkzeug==3.0.1
249
+ wrapt==1.15.0
250
+ wsproto==1.2.0
251
+ xmljson==0.2.0
252
+ xmltodict==0.13.0
src/configs/model_config.py CHANGED
@@ -5,6 +5,7 @@ class ModelConfig:
5
  self.learning_rate = 0.001
6
  self.batch_size = 32
7
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
- self.epochs = 20
 
9
  def get_config(self):
10
  return self
 
5
  self.learning_rate = 0.001
6
  self.batch_size = 32
7
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ self.epochs = 5
9
+ self.log_interval = 2 # Log every 2 batches => number of items is 32*2 = 64
10
  def get_config(self):
11
  return self
src/data/data_loader.py CHANGED
@@ -7,7 +7,8 @@ import os
7
  num_classes = 3
8
  config = ModelConfig().get_config()
9
 
10
- train_dataset = CustomDataset(data_folder=os.path.join("data", 'raw'), transform=data_transform)
 
11
 
12
  # # Calculate the split point
13
  # split_index = int(0.8 * len(dataset))
@@ -17,5 +18,11 @@ train_dataset = CustomDataset(data_folder=os.path.join("data", 'raw'), transform
17
  # test_dataset = dataset[split_index:]
18
 
19
 
20
- train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
 
 
 
 
 
 
21
 
 
7
  num_classes = 3
8
  config = ModelConfig().get_config()
9
 
10
+ train_dataset = CustomDataset(data_folder=os.path.join(
11
+ "data", 'raw'), transform=data_transform)
12
 
13
  # # Calculate the split point
14
  # split_index = int(0.8 * len(dataset))
 
18
  # test_dataset = dataset[split_index:]
19
 
20
 
21
+ train_loader = DataLoader(
22
+ train_dataset, batch_size=config.batch_size, shuffle=True)
23
+
24
+
25
+ def get_train_dataset(batch_size):
26
+ return DataLoader(
27
+ train_dataset, batch_size=batch_size, shuffle=True)
28
 
src/models/model.py CHANGED
@@ -1,17 +1,34 @@
1
  from torch import nn
2
  import torch.nn.functional as F
3
-
4
  class ShapeClassifier(nn.Module):
5
- def __init__(self, num_classes):
6
  super(ShapeClassifier, self).__init__()
7
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
8
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
9
- self.fc1 = nn.Linear(16 * 64 * 64, 128)
10
- self.fc2 = nn.Linear(128, num_classes)
 
 
 
 
 
 
 
11
 
12
  def forward(self, x):
 
 
 
13
  x = self.pool(F.relu(self.conv1(x)))
14
- x = x.view(-1, 16 * 64 * 64) # Adjust the dimensions based on your input image size
 
 
 
 
15
  x = F.relu(self.fc1(x))
 
 
16
  x = self.fc2(x)
17
- return x
 
 
1
  from torch import nn
2
  import torch.nn.functional as F
3
+ # Ảnh gốc có kích thước 128x128x3
4
  class ShapeClassifier(nn.Module):
5
+ def __init__(self, num_classes, hidden_size=128):
6
  super(ShapeClassifier, self).__init__()
7
+ # Layer 1: Convolutional layer with 3 input channels (RGB) and 16 output channels, using a 3x3 kernel and padding of 1
8
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1) # ra 128x128x16
9
+
10
+ # Layer 2: Max pooling layer with a 2x2 kernel and stride of 2 to reduce spatial dimensions
11
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # ra 64x64x16
12
+
13
+ # Layer 3: Fully connected layer with input size 16 * 64 * 64 (depends on the input image size) and output size 128
14
+ self.fc1 = nn.Linear(16 * 64 * 64, hidden_size)
15
+
16
+ # Layer 4: Fully connected layer with input size 128 and output size num_classes
17
+ self.fc2 = nn.Linear(hidden_size, num_classes)
18
 
19
  def forward(self, x):
20
+ # Forward pass through the network
21
+
22
+ # Apply convolution, activation function (ReLU), and max pooling
23
  x = self.pool(F.relu(self.conv1(x)))
24
+
25
+ # Adjust the dimensions for the fully connected layer
26
+ x = x.view(-1, 16 * 64 * 64)
27
+
28
+ # Apply activation function (ReLU) to the first fully connected layer
29
  x = F.relu(self.fc1(x))
30
+
31
+ # Output layer without activation function (applied later during loss computation)
32
  x = self.fc2(x)
33
+
34
+ return x
src/train.py CHANGED
@@ -1,51 +1,40 @@
1
  import torch
2
  import torch.optim as optim
3
  import torch.nn.functional as F
4
- from .models.model import ShapeClassifier
5
 
6
  from src.configs.model_config import ModelConfig
7
  from src.data.data_loader import train_loader, num_classes
 
 
 
 
 
 
8
 
9
 
10
- def train():
 
11
  config = ModelConfig().get_config()
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model = ShapeClassifier(num_classes=num_classes).to(device)
14
  optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
15
  log_interval = 20
 
 
 
16
  for epoch in range(config.epochs):
17
- model.train()
18
- running_loss = 0.0
19
-
20
- for batch_idx, (inputs, labels) in enumerate(train_loader):
21
- inputs, labels = inputs.to(device), labels.to(device)
22
- optimizer.zero_grad()
23
-
24
- outputs = model(inputs)
25
- loss = F.cross_entropy(outputs, labels)
26
- loss.backward()
27
- optimizer.step()
28
-
29
- running_loss += loss.item()
30
-
31
- if batch_idx % log_interval == 0:
32
- current_loss = running_loss / log_interval
33
- print(
34
- f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
35
- running_loss = 0.0
36
-
37
- # calculate the accuracy on the test set
38
-
39
- with torch.no_grad():
40
- model.eval()
41
- correct = 0
42
- total = 0
43
- for inputs, labels in train_loader:
44
- inputs, labels = inputs.to(device), labels.to(device)
45
- outputs = model(inputs)
46
- predicted = torch.argmax(outputs.data, 1)
47
- total += labels.size(0)
48
- correct += (predicted == labels).sum().item()
49
- print(f"Accuracy of the model on the test images: {100 * correct / total} %")
50
- # save the model
51
- torch.save(model.state_dict(), "model.pth")
 
1
  import torch
2
  import torch.optim as optim
3
  import torch.nn.functional as F
4
+ from src.models.model import ShapeClassifier
5
 
6
  from src.configs.model_config import ModelConfig
7
  from src.data.data_loader import train_loader, num_classes
8
+ from src.utils.logs import writer
9
+ from src.utils.train import train
10
+ from src.utils.test import test
11
+ import wandb
12
+ import json
13
+ wandb.init(project="template-pytorch-model", entity="nguyen")
14
 
15
 
16
+ def train_runner():
17
+
18
  config = ModelConfig().get_config()
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model = ShapeClassifier(num_classes=num_classes).to(device)
21
  optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
22
  log_interval = 20
23
+ # log models config to wandb
24
+ wandb.config.update(config)
25
+
26
  for epoch in range(config.epochs):
27
+ print(f"Epoch {epoch+1}\n-------------------------------")
28
+
29
+ loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
30
+ optimizer=optimizer)
31
+ test(train_loader, model=model, loss_fn=F.cross_entropy)
32
+ # 3. Log metrics over time to visualize performance
33
+ wandb.log({"loss": loss})
34
+
35
+ # save model
36
+ torch.save(model.state_dict(), "model.pth")
37
+
38
+ # 4. Log an artifact to W&B
39
+ wandb.log_artifact("model.pth")
40
+ # model.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ from inspect import isclass
4
+
5
+ # import all files under utils/
6
+ utils_dir = os.path.dirname(__file__)
7
+ for file in os.listdir(utils_dir):
8
+ path = os.path.join(utils_dir, file)
9
+ if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
10
+ config_name = file[: file.find(".py")] if file.endswith(".py") else file
11
+ module = importlib.import_module("src.utils." + config_name)
12
+ for attribute_name in dir(module):
13
+ attribute = getattr(module, attribute_name)
14
+
15
+ if isclass(attribute):
16
+ # Add the class to this package's variables
17
+ globals()[attribute_name] = attribute
src/utils/logs.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from torch.utils.tensorboard import SummaryWriter
2
+ writer = SummaryWriter()
src/utils/test.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.configs.model_config import ModelConfig
2
+ import torch
3
+
4
+ def test(dataloader, model, loss_fn):
5
+ config = ModelConfig().get_config()
6
+ size = len(dataloader.dataset)
7
+ num_batches = len(dataloader)
8
+ model.eval()
9
+ test_loss, correct = 0, 0
10
+ with torch.no_grad():
11
+ for X, y in dataloader:
12
+ X, y = X.to(config.device), y.to(config.device)
13
+ pred = model(X)
14
+ test_loss += loss_fn(pred, y).item()
15
+ correct += (pred.argmax(1) == y).type(torch.float).sum().item()
16
+ test_loss /= num_batches
17
+ correct /= size
18
+ print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
src/utils/train.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.configs.model_config import ModelConfig
2
+ import torch
3
+ def train(dataloader, model, loss_fn, optimizer):
4
+ config = ModelConfig().get_config()
5
+ size = len(dataloader.dataset)
6
+ model.train()
7
+ for batch, (X, y) in enumerate(dataloader):
8
+ X, y = X.to(config.device), y.to(config.device)
9
+
10
+ # Compute prediction error
11
+ pred = model(X)
12
+ loss = loss_fn(pred, y)
13
+
14
+ # Backpropagation
15
+ loss.backward()
16
+ optimizer.step()
17
+ optimizer.zero_grad()
18
+
19
+ if batch % config.log_interval == 0:
20
+ loss, current = loss.item(), (batch + 1) * len(X)
21
+ print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
22
+ # return loss
23
+ return loss
src/utils/utils.py ADDED
@@ -0,0 +1 @@
 
 
1
+