Spaces:
Build error
Build error
update
Browse files- main.py +5 -7
- src/data/data_loader.py +17 -8
- src/utils/dataset.py +0 -0
main.py
CHANGED
@@ -6,17 +6,15 @@ import torch.nn.functional as F
|
|
6 |
from src.models.model import ShapeClassifier
|
7 |
|
8 |
from src.configs.model_config import ModelConfig
|
9 |
-
from src.data.data_loader import train_loader, num_classes
|
10 |
-
from src.utils.tensorboard import writer
|
11 |
from src.utils.train import train
|
12 |
from src.utils.test import test
|
13 |
from src.utils.wandb import wandb
|
14 |
from src.utils.logs import logging
|
15 |
-
import json
|
16 |
from src.utils.model import save_model
|
17 |
|
18 |
|
19 |
-
def
|
20 |
|
21 |
config = ModelConfig().get_config()
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -30,7 +28,7 @@ def train_runner():
|
|
30 |
|
31 |
loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
|
32 |
optimizer=optimizer)
|
33 |
-
test(
|
34 |
# 3. Log metrics over time to visualize performance
|
35 |
wandb.log({"loss": loss})
|
36 |
|
@@ -38,10 +36,10 @@ def train_runner():
|
|
38 |
save_model(model, "results/models/last.pth")
|
39 |
|
40 |
# 4. Log an artifact to W&B
|
41 |
-
wandb.log_artifact("model.pth")
|
42 |
# model.train()
|
43 |
|
44 |
|
45 |
if __name__ == "__main__":
|
46 |
logging.info("Training model")
|
47 |
-
|
|
|
6 |
from src.models.model import ShapeClassifier
|
7 |
|
8 |
from src.configs.model_config import ModelConfig
|
9 |
+
from src.data.data_loader import train_loader, num_classes, val_loader
|
|
|
10 |
from src.utils.train import train
|
11 |
from src.utils.test import test
|
12 |
from src.utils.wandb import wandb
|
13 |
from src.utils.logs import logging
|
|
|
14 |
from src.utils.model import save_model
|
15 |
|
16 |
|
17 |
+
def main():
|
18 |
|
19 |
config = ModelConfig().get_config()
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
28 |
|
29 |
loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
|
30 |
optimizer=optimizer)
|
31 |
+
test(val_loader, model=model, loss_fn=F.cross_entropy)
|
32 |
# 3. Log metrics over time to visualize performance
|
33 |
wandb.log({"loss": loss})
|
34 |
|
|
|
36 |
save_model(model, "results/models/last.pth")
|
37 |
|
38 |
# 4. Log an artifact to W&B
|
39 |
+
# wandb.log_artifact("model.pth")
|
40 |
# model.train()
|
41 |
|
42 |
|
43 |
if __name__ == "__main__":
|
44 |
logging.info("Training model")
|
45 |
+
main()
|
src/data/data_loader.py
CHANGED
@@ -3,26 +3,35 @@ from torch.utils.data import DataLoader
|
|
3 |
from src.configs.model_config import ModelConfig
|
4 |
from .transform import data_transform
|
5 |
import os
|
|
|
|
|
6 |
|
7 |
num_classes = 3
|
8 |
config = ModelConfig().get_config()
|
9 |
|
10 |
-
|
11 |
"data", 'raw'), transform=data_transform)
|
12 |
|
13 |
-
#
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
# test_dataset = dataset[split_index:]
|
19 |
|
20 |
|
21 |
train_loader = DataLoader(
|
22 |
train_dataset, batch_size=config.batch_size, shuffle=True)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def get_train_dataset(batch_size):
|
26 |
return DataLoader(
|
27 |
-
|
28 |
-
|
|
|
3 |
from src.configs.model_config import ModelConfig
|
4 |
from .transform import data_transform
|
5 |
import os
|
6 |
+
from torch.utils.data import random_split
|
7 |
+
|
8 |
|
9 |
num_classes = 3
|
10 |
config = ModelConfig().get_config()
|
11 |
|
12 |
+
all_dataset = CustomDataset(data_folder=os.path.join(
|
13 |
"data", 'raw'), transform=data_transform)
|
14 |
|
15 |
+
# split to train, val, test
|
16 |
+
total_size = len(all_dataset)
|
17 |
+
train_size = int(0.8 * total_size)
|
18 |
+
val_size = int(0.1 * total_size)
|
19 |
+
test_size = total_size - train_size - val_size
|
20 |
|
21 |
+
train_dataset, val_dataset, test_dataset = random_split(
|
22 |
+
all_dataset, [train_size, val_size, test_size])
|
|
|
23 |
|
24 |
|
25 |
train_loader = DataLoader(
|
26 |
train_dataset, batch_size=config.batch_size, shuffle=True)
|
27 |
|
28 |
+
val_loader = DataLoader(
|
29 |
+
val_dataset, batch_size=config.batch_size, shuffle=True)
|
30 |
+
|
31 |
+
test_loader = DataLoader(
|
32 |
+
test_dataset, batch_size=config.batch_size, shuffle=True)
|
33 |
+
|
34 |
|
35 |
def get_train_dataset(batch_size):
|
36 |
return DataLoader(
|
37 |
+
all_dataset, batch_size=batch_size, shuffle=True)
|
|
src/utils/dataset.py
ADDED
File without changes
|