npv2k1 commited on
Commit
06c8a6d
·
verified ·
1 Parent(s): 0e63e05
.gitignore CHANGED
@@ -163,11 +163,14 @@ cython_debug/
163
  # ignore dataset but not the folder
164
  data/raw/*
165
  data/processed/*
 
166
 
167
  !data/raw/.gitkeep
168
  !data/processed/.gitkeep
 
169
 
170
- uvenv/
171
  runs/
172
- wandb/
173
- ubuntu-venv/
 
 
163
  # ignore dataset but not the folder
164
  data/raw/*
165
  data/processed/*
166
+ results/models/*
167
 
168
  !data/raw/.gitkeep
169
  !data/processed/.gitkeep
170
+ !results/models/.gitkeep
171
 
172
+ # tensorboard logs
173
  runs/
174
+
175
+ # wandb logs
176
+ wandb/
Makefile CHANGED
@@ -1,4 +1,6 @@
1
  package:
2
  pip freeze > requirements.txt
3
  venv:
4
- source /mnt/d/ubuntu/env/mlenv/bin/activate
 
 
 
1
  package:
2
  pip freeze > requirements.txt
3
  venv:
4
+ source /mnt/d/ubuntu/env/mlenv/bin/activate
5
+ tensorboard:
6
+ tensorboard --inspect --logdir logs/tensorboard
app.py CHANGED
@@ -18,7 +18,7 @@ def classify_drawing(drawing_image):
18
  num_classes = 3 # Set the number of classes
19
  # Initialize your model class
20
  model = ShapeClassifier(num_classes=num_classes)
21
- model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
22
  model.eval() # Set the model to evaluation mode
23
 
24
  # Convert the drawing to a grayscale image
@@ -45,6 +45,5 @@ iface = gr.Interface(
45
  inputs=gr.Image(type="pil"), # Use Sketchpad as input
46
  outputs="text",
47
  live=True,
48
- capture_session=True,
49
  )
50
  iface.launch(server_port=7860)
 
18
  num_classes = 3 # Set the number of classes
19
  # Initialize your model class
20
  model = ShapeClassifier(num_classes=num_classes)
21
+ model.load_state_dict(torch.load('results/models/model.pth', map_location=torch.device('cpu')))
22
  model.eval() # Set the model to evaluation mode
23
 
24
  # Convert the drawing to a grayscale image
 
45
  inputs=gr.Image(type="pil"), # Use Sketchpad as input
46
  outputs="text",
47
  live=True,
 
48
  )
49
  iface.launch(server_port=7860)
data/raw/.gitkeep DELETED
File without changes
main.py CHANGED
@@ -1,10 +1,47 @@
1
- from src.train import train_runner
2
- from src.auto import auto_hyper_parameter
3
- import os
4
- # set WANDB_API_KEY=$YOUR_API_KEY
5
- # os.environ["WANDB_API_KEY"] = '7c0f2b9470a0a5c82bfae5bab4705344cb53288b'
6
- # os.environ['WANDB_MODE'] = "offline"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if __name__ == "__main__":
8
- print("Training the model...")
9
- # train_runner()
10
- auto_hyper_parameter()
 
1
+
2
+
3
+ import torch
4
+ import torch.optim as optim
5
+ import torch.nn.functional as F
6
+ from src.models.model import ShapeClassifier
7
+
8
+ from src.configs.model_config import ModelConfig
9
+ from src.data.data_loader import train_loader, num_classes
10
+ from src.utils.tensorboard import writer
11
+ from src.utils.train import train
12
+ from src.utils.test import test
13
+ from src.utils.wandb import wandb
14
+ from src.utils.logs import logging
15
+ import json
16
+ from src.utils.model import save_model
17
+
18
+
19
+ def train_runner():
20
+
21
+ config = ModelConfig().get_config()
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model = ShapeClassifier(num_classes=num_classes).to(device)
24
+ optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
25
+ # log models config to wandb
26
+ wandb.config.update(config)
27
+
28
+ for epoch in range(config.epochs):
29
+ print(f"Epoch {epoch+1}\n-------------------------------")
30
+
31
+ loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
32
+ optimizer=optimizer)
33
+ test(train_loader, model=model, loss_fn=F.cross_entropy)
34
+ # 3. Log metrics over time to visualize performance
35
+ wandb.log({"loss": loss})
36
+
37
+ # save model
38
+ save_model(model, "results/models/last.pth")
39
+
40
+ # 4. Log an artifact to W&B
41
+ wandb.log_artifact("model.pth")
42
+ # model.train()
43
+
44
+
45
  if __name__ == "__main__":
46
+ logging.info("Training model")
47
+ train_runner()
 
{data/processed → results/models}/.gitkeep RENAMED
File without changes
src/configs/model_config.py CHANGED
@@ -1,11 +1,19 @@
1
  import torch
2
 
 
3
  class ModelConfig:
4
  def __init__(self):
5
  self.learning_rate = 0.001
6
  self.batch_size = 32
7
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
  self.epochs = 5
9
- self.log_interval = 2 # Log every 2 batches => number of items is 32*2 = 64
 
 
 
 
 
 
 
10
  def get_config(self):
11
- return self
 
1
  import torch
2
 
3
+
4
  class ModelConfig:
5
  def __init__(self):
6
  self.learning_rate = 0.001
7
  self.batch_size = 32
8
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
9
  self.epochs = 5
10
+ self.log_interval = 2 # Log every 2 batches => number of items is 32*2 = 64
11
+
12
+ # Wandb config
13
+ self.wandb = True
14
+ self.wandb_project = "template-pytorch-model"
15
+ self.wandb_entity = "nguyen"
16
+ self.wandb_api_key = ""
17
+
18
  def get_config(self):
19
+ return self
src/train.py DELETED
@@ -1,40 +0,0 @@
1
- import torch
2
- import torch.optim as optim
3
- import torch.nn.functional as F
4
- from src.models.model import ShapeClassifier
5
-
6
- from src.configs.model_config import ModelConfig
7
- from src.data.data_loader import train_loader, num_classes
8
- from src.utils.logs import writer
9
- from src.utils.train import train
10
- from src.utils.test import test
11
- import wandb
12
- import json
13
- wandb.init(project="template-pytorch-model", entity="nguyen")
14
-
15
-
16
- def train_runner():
17
-
18
- config = ModelConfig().get_config()
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- model = ShapeClassifier(num_classes=num_classes).to(device)
21
- optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
22
- log_interval = 20
23
- # log models config to wandb
24
- wandb.config.update(config)
25
-
26
- for epoch in range(config.epochs):
27
- print(f"Epoch {epoch+1}\n-------------------------------")
28
-
29
- loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
30
- optimizer=optimizer)
31
- test(train_loader, model=model, loss_fn=F.cross_entropy)
32
- # 3. Log metrics over time to visualize performance
33
- wandb.log({"loss": loss})
34
-
35
- # save model
36
- torch.save(model.state_dict(), "model.pth")
37
-
38
- # 4. Log an artifact to W&B
39
- wandb.log_artifact("model.pth")
40
- # model.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/__init__.py CHANGED
@@ -1,17 +1,8 @@
1
- import importlib
2
- import os
3
- from inspect import isclass
4
 
5
- # import all files under utils/
6
- utils_dir = os.path.dirname(__file__)
7
- for file in os.listdir(utils_dir):
8
- path = os.path.join(utils_dir, file)
9
- if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
10
- config_name = file[: file.find(".py")] if file.endswith(".py") else file
11
- module = importlib.import_module("src.utils." + config_name)
12
- for attribute_name in dir(module):
13
- attribute = getattr(module, attribute_name)
14
 
15
- if isclass(attribute):
16
- # Add the class to this package's variables
17
- globals()[attribute_name] = attribute
 
 
 
 
 
 
 
1
 
 
 
 
 
 
 
 
 
 
2
 
3
+ # import all files under utils/
4
+ from os.path import dirname, basename, isfile, join
5
+ import glob
6
+ modules = glob.glob(join(dirname(__file__), "*.py"))
7
+ __all__ = [basename(f)[:-3] for f in modules if isfile(f)
8
+ and not f.endswith('__init__.py')]
src/utils/file.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ # enssure folder exists
4
+
5
+
6
+ def ensure_folder_exists(path):
7
+ if not os.path.exists(path):
8
+ os.makedirs(path)
src/utils/logs.py CHANGED
@@ -1,2 +1,34 @@
1
- from torch.utils.tensorboard import SummaryWriter
2
- writer = SummaryWriter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ log_dir = "logs"
5
+ log_level = logging.INFO
6
+
7
+ if not os.path.exists(log_dir):
8
+ os.makedirs(log_dir)
9
+
10
+ log_filename = os.path.join(log_dir, "app.log")
11
+
12
+ logging.basicConfig(
13
+ filename=log_filename,
14
+ level=log_level,
15
+ format="%(asctime)s [%(levelname)s]: %(message)s",
16
+ datefmt="%Y-%m-%d %H:%M:%S"
17
+ )
18
+
19
+ console_handler = logging.StreamHandler()
20
+ console_handler.setLevel(log_level)
21
+ console_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s"))
22
+ logging.getLogger().addHandler(console_handler)
23
+
24
+ def log_info(message):
25
+ logging.info(message)
26
+
27
+ def log_warning(message):
28
+ logging.warning(message)
29
+
30
+ def log_error(message):
31
+ logging.error(message)
32
+
33
+ def log_exception(message):
34
+ logging.exception(message)
src/utils/model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # save model utils
2
+ import torch
3
+ import torch
4
+ import os
5
+
6
+
7
+ def save_model(model: torch.nn.Module, path: str) -> str:
8
+ parent_folder = os.path.dirname(path)
9
+ os.makedirs(parent_folder, exist_ok=True)
10
+ torch.save(model.state_dict(), path)
11
+ return path
12
+
13
+ def load_model(model: torch.nn.Module, path: str) -> torch.nn.Module:
14
+ model.load_state_dict(torch.load(path))
15
+ return model
src/utils/tensorboard.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from torch.utils.tensorboard import SummaryWriter
2
+ writer = SummaryWriter()
src/utils/wandb.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ from .logs import log_info
3
+ from src.configs.model_config import ModelConfig
4
+ config = ModelConfig().get_config()
5
+
6
+ if config.wandb:
7
+ project = config.wandb_project
8
+ entity = config.wandb_entity
9
+ api_key = config.wandb_api_key
10
+ wandb.login(key=api_key)
11
+ wandb.init(project=project, entity=entity)
12
+ log_info("Wandb is enabled")
13
+ else:
14
+ log_info("Wandb is disabled")
15
+ # disable wandb
16
+ wandb.init(mode="disabled")
17
+ pass