npv2k1 committed
Commit f6c0b9e · verified · 1 Parent(s): b6fee08
Files changed (3)
  1. main.py +5 -7
  2. src/data/data_loader.py +17 -8
  3. src/utils/dataset.py +0 -0
main.py CHANGED
```diff
@@ -6,17 +6,15 @@ import torch.nn.functional as F
 from src.models.model import ShapeClassifier
 
 from src.configs.model_config import ModelConfig
-from src.data.data_loader import train_loader, num_classes
-from src.utils.tensorboard import writer
+from src.data.data_loader import train_loader, num_classes, val_loader
 from src.utils.train import train
 from src.utils.test import test
 from src.utils.wandb import wandb
 from src.utils.logs import logging
-import json
 from src.utils.model import save_model
 
 
-def train_runner():
+def main():
 
     config = ModelConfig().get_config()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -30,7 +28,7 @@ def train_runner():
 
     loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
                  optimizer=optimizer)
-    test(train_loader, model=model, loss_fn=F.cross_entropy)
+    test(val_loader, model=model, loss_fn=F.cross_entropy)
     # 3. Log metrics over time to visualize performance
     wandb.log({"loss": loss})
 
@@ -38,10 +36,10 @@ def train_runner():
     save_model(model, "results/models/last.pth")
 
     # 4. Log an artifact to W&B
-    wandb.log_artifact("model.pth")
+    # wandb.log_artifact("model.pth")
    # model.train()
 
 
 if __name__ == "__main__":
     logging.info("Training model")
-    train_runner()
+    main()
```
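Note on the change above: evaluation now runs against `val_loader` rather than re-scoring `train_loader`. The `test` helper from `src.utils.test` is not shown in this commit; a minimal sketch of an evaluation loop matching its call signature (an assumed implementation, not the repository's actual code) might look like:

```python
import torch


def test(dataloader, model, loss_fn):
    # Hypothetical evaluation loop: mean loss and accuracy on a held-out split.
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for inputs, targets in dataloader:
            logits = model(inputs)
            total_loss += loss_fn(logits, targets).item() * targets.size(0)
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
    return total_loss / seen, correct / seen
```

Evaluating on a split the optimizer never sees is what makes the logged metric meaningful as a generalization signal.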
 
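The commit also comments out `wandb.log_artifact("model.pth")`, which pointed at a path that no longer matches where the checkpoint is written (`results/models/last.pth`). If artifact logging is reinstated, the usual W&B pattern is to wrap the file in a named `Artifact`; a sketch, with the project and artifact names as placeholder choices:

```python
import wandb

run = wandb.init(project="shape-classifier")  # placeholder project name

# Wrap the saved checkpoint in a typed artifact and attach it to the run.
artifact = wandb.Artifact("shape-classifier-model", type="model")
artifact.add_file("results/models/last.pth")
run.log_artifact(artifact)
```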
src/data/data_loader.py CHANGED
```diff
@@ -3,26 +3,35 @@ from torch.utils.data import DataLoader
 from src.configs.model_config import ModelConfig
 from .transform import data_transform
 import os
+from torch.utils.data import random_split
+
 
 num_classes = 3
 config = ModelConfig().get_config()
 
-train_dataset = CustomDataset(data_folder=os.path.join(
+all_dataset = CustomDataset(data_folder=os.path.join(
     "data", 'raw'), transform=data_transform)
 
-# # Calculate the split point
-# split_index = int(0.8 * len(dataset))
+# split to train, val, test
+total_size = len(all_dataset)
+train_size = int(0.8 * total_size)
+val_size = int(0.1 * total_size)
+test_size = total_size - train_size - val_size
 
-# # Split the dataset into training and testing
-# train_dataset = dataset[:split_index]
-# test_dataset = dataset[split_index:]
+train_dataset, val_dataset, test_dataset = random_split(
+    all_dataset, [train_size, val_size, test_size])
 
 
 train_loader = DataLoader(
     train_dataset, batch_size=config.batch_size, shuffle=True)
 
+val_loader = DataLoader(
+    val_dataset, batch_size=config.batch_size, shuffle=True)
+
+test_loader = DataLoader(
+    test_dataset, batch_size=config.batch_size, shuffle=True)
+
 
 def get_train_dataset(batch_size):
     return DataLoader(
-        train_dataset, batch_size=batch_size, shuffle=True)
-
+        all_dataset, batch_size=batch_size, shuffle=True)
```
 
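Two details of the new split are worth flagging. `random_split` draws a fresh permutation on every run, so train/val/test membership changes between invocations unless a seeded generator is supplied, and `shuffle=True` on the val/test loaders buys nothing at evaluation time (`shuffle=False` is the conventional choice there). A sketch of a seeded, deterministic variant, reusing the names defined in this file (the seed value is arbitrary):

```python
import torch
from torch.utils.data import DataLoader, random_split

# A seeded generator makes the 80/10/10 split identical across runs.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    all_dataset, [train_size, val_size, test_size], generator=generator)

# Shuffling only helps training; evaluation order can stay deterministic.
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
```

Also note that `get_train_dataset` now iterates over `all_dataset`, i.e. over the validation and test samples as well, which is only safe for callers that genuinely want the full dataset.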
 
src/utils/dataset.py ADDED
File without changes
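`src/utils/dataset.py` is added empty here, while `data_loader.py` constructs `CustomDataset(data_folder=..., transform=...)`; the class definition itself is not part of this commit. For orientation, a hypothetical directory-per-class implementation matching that constructor signature could look like:

```python
import os
from PIL import Image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    # Hypothetical sketch: expects one subfolder per shape class under data_folder.
    def __init__(self, data_folder, transform=None):
        self.transform = transform
        self.samples = []  # (image_path, class_index) pairs
        for label, cls in enumerate(sorted(os.listdir(data_folder))):
            cls_dir = os.path.join(data_folder, cls)
            for name in sorted(os.listdir(cls_dir)):
                self.samples.append((os.path.join(cls_dir, name), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```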