Spaces:
Build error
Build error
update
Browse files- .gitignore +6 -3
- Makefile +3 -1
- app.py +1 -2
- data/raw/.gitkeep +0 -0
- main.py +46 -9
- {data/processed → results/models}/.gitkeep +0 -0
- src/configs/model_config.py +10 -2
- src/train.py +0 -40
- src/utils/__init__.py +6 -15
- src/utils/file.py +8 -0
- src/utils/logs.py +34 -2
- src/utils/model.py +15 -0
- src/utils/tensorboard.py +2 -0
- src/utils/wandb.py +17 -0
.gitignore
CHANGED
@@ -163,11 +163,14 @@ cython_debug/
|
|
163 |
# ignore dataset but not the folder
|
164 |
data/raw/*
|
165 |
data/processed/*
|
|
|
166 |
|
167 |
!data/raw/.gitkeep
|
168 |
!data/processed/.gitkeep
|
|
|
169 |
|
170 |
-
|
171 |
runs/
|
172 |
-
|
173 |
-
|
|
|
|
163 |
# ignore dataset but not the folder
|
164 |
data/raw/*
|
165 |
data/processed/*
|
166 |
+
results/models/*
|
167 |
|
168 |
!data/raw/.gitkeep
|
169 |
!data/processed/.gitkeep
|
170 |
+
!results/models/.gitkeep
|
171 |
|
172 |
+
# tensorboard logs
|
173 |
runs/
|
174 |
+
|
175 |
+
# wandb logs
|
176 |
+
wandb/
|
Makefile
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
package:
|
2 |
pip freeze > requirements.txt
|
3 |
venv:
|
4 |
-
source /mnt/d/ubuntu/env/mlenv/bin/activate
|
|
|
|
|
|
1 |
package:
|
2 |
pip freeze > requirements.txt
|
3 |
venv:
|
4 |
+
source /mnt/d/ubuntu/env/mlenv/bin/activate
|
5 |
+
tensorboard:
|
6 |
+
tensorboard --inspect --logdir logs/tensorboard
|
app.py
CHANGED
@@ -18,7 +18,7 @@ def classify_drawing(drawing_image):
|
|
18 |
num_classes = 3 # Set the number of classes
|
19 |
# Initialize your model class
|
20 |
model = ShapeClassifier(num_classes=num_classes)
|
21 |
-
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
|
22 |
model.eval() # Set the model to evaluation mode
|
23 |
|
24 |
# Convert the drawing to a grayscale image
|
@@ -45,6 +45,5 @@ iface = gr.Interface(
|
|
45 |
inputs=gr.Image(type="pil"), # Use Sketchpad as input
|
46 |
outputs="text",
|
47 |
live=True,
|
48 |
-
capture_session=True,
|
49 |
)
|
50 |
iface.launch(server_port=7860)
|
|
|
18 |
num_classes = 3 # Set the number of classes
|
19 |
# Initialize your model class
|
20 |
model = ShapeClassifier(num_classes=num_classes)
|
21 |
+
model.load_state_dict(torch.load('results/models/model.pth', map_location=torch.device('cpu')))
|
22 |
model.eval() # Set the model to evaluation mode
|
23 |
|
24 |
# Convert the drawing to a grayscale image
|
|
|
45 |
inputs=gr.Image(type="pil"), # Use Sketchpad as input
|
46 |
outputs="text",
|
47 |
live=True,
|
|
|
48 |
)
|
49 |
iface.launch(server_port=7860)
|
data/raw/.gitkeep
DELETED
File without changes
|
main.py
CHANGED
@@ -1,10 +1,47 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
if __name__ == "__main__":
|
8 |
-
|
9 |
-
|
10 |
-
auto_hyper_parameter()
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.optim as optim
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from src.models.model import ShapeClassifier
|
7 |
+
|
8 |
+
from src.configs.model_config import ModelConfig
|
9 |
+
from src.data.data_loader import train_loader, num_classes
|
10 |
+
from src.utils.tensorboard import writer
|
11 |
+
from src.utils.train import train
|
12 |
+
from src.utils.test import test
|
13 |
+
from src.utils.wandb import wandb
|
14 |
+
from src.utils.logs import logging
|
15 |
+
import json
|
16 |
+
from src.utils.model import save_model
|
17 |
+
|
18 |
+
|
19 |
+
def train_runner():
|
20 |
+
|
21 |
+
config = ModelConfig().get_config()
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
model = ShapeClassifier(num_classes=num_classes).to(device)
|
24 |
+
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
25 |
+
# log models config to wandb
|
26 |
+
wandb.config.update(config)
|
27 |
+
|
28 |
+
for epoch in range(config.epochs):
|
29 |
+
print(f"Epoch {epoch+1}\n-------------------------------")
|
30 |
+
|
31 |
+
loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
|
32 |
+
optimizer=optimizer)
|
33 |
+
test(train_loader, model=model, loss_fn=F.cross_entropy)
|
34 |
+
# 3. Log metrics over time to visualize performance
|
35 |
+
wandb.log({"loss": loss})
|
36 |
+
|
37 |
+
# save model
|
38 |
+
save_model(model, "results/models/last.pth")
|
39 |
+
|
40 |
+
# 4. Log an artifact to W&B
|
41 |
+
wandb.log_artifact("model.pth")
|
42 |
+
# model.train()
|
43 |
+
|
44 |
+
|
45 |
if __name__ == "__main__":
|
46 |
+
logging.info("Training model")
|
47 |
+
train_runner()
|
|
{data/processed → results/models}/.gitkeep
RENAMED
File without changes
|
src/configs/model_config.py
CHANGED
@@ -1,11 +1,19 @@
|
|
1 |
import torch
|
2 |
|
|
|
3 |
class ModelConfig:
|
4 |
def __init__(self):
|
5 |
self.learning_rate = 0.001
|
6 |
self.batch_size = 32
|
7 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
self.epochs = 5
|
9 |
-
self.log_interval = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def get_config(self):
|
11 |
-
return self
|
|
|
1 |
import torch
|
2 |
|
3 |
+
|
4 |
class ModelConfig:
|
5 |
def __init__(self):
|
6 |
self.learning_rate = 0.001
|
7 |
self.batch_size = 32
|
8 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
self.epochs = 5
|
10 |
+
self.log_interval = 2 # Log every 2 batches => number of items is 32*2 = 64
|
11 |
+
|
12 |
+
# Wandb config
|
13 |
+
self.wandb = True
|
14 |
+
self.wandb_project = "template-pytorch-model"
|
15 |
+
self.wandb_entity = "nguyen"
|
16 |
+
self.wandb_api_key = ""
|
17 |
+
|
18 |
def get_config(self):
|
19 |
+
return self
|
src/train.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.optim as optim
|
3 |
-
import torch.nn.functional as F
|
4 |
-
from src.models.model import ShapeClassifier
|
5 |
-
|
6 |
-
from src.configs.model_config import ModelConfig
|
7 |
-
from src.data.data_loader import train_loader, num_classes
|
8 |
-
from src.utils.logs import writer
|
9 |
-
from src.utils.train import train
|
10 |
-
from src.utils.test import test
|
11 |
-
import wandb
|
12 |
-
import json
|
13 |
-
wandb.init(project="template-pytorch-model", entity="nguyen")
|
14 |
-
|
15 |
-
|
16 |
-
def train_runner():
|
17 |
-
|
18 |
-
config = ModelConfig().get_config()
|
19 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
-
model = ShapeClassifier(num_classes=num_classes).to(device)
|
21 |
-
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
22 |
-
log_interval = 20
|
23 |
-
# log models config to wandb
|
24 |
-
wandb.config.update(config)
|
25 |
-
|
26 |
-
for epoch in range(config.epochs):
|
27 |
-
print(f"Epoch {epoch+1}\n-------------------------------")
|
28 |
-
|
29 |
-
loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
|
30 |
-
optimizer=optimizer)
|
31 |
-
test(train_loader, model=model, loss_fn=F.cross_entropy)
|
32 |
-
# 3. Log metrics over time to visualize performance
|
33 |
-
wandb.log({"loss": loss})
|
34 |
-
|
35 |
-
# save model
|
36 |
-
torch.save(model.state_dict(), "model.pth")
|
37 |
-
|
38 |
-
# 4. Log an artifact to W&B
|
39 |
-
wandb.log_artifact("model.pth")
|
40 |
-
# model.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/__init__.py
CHANGED
@@ -1,17 +1,8 @@
|
|
1 |
-
import importlib
|
2 |
-
import os
|
3 |
-
from inspect import isclass
|
4 |
|
5 |
-
# import all files under utils/
|
6 |
-
utils_dir = os.path.dirname(__file__)
|
7 |
-
for file in os.listdir(utils_dir):
|
8 |
-
path = os.path.join(utils_dir, file)
|
9 |
-
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
10 |
-
config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
11 |
-
module = importlib.import_module("src.utils." + config_name)
|
12 |
-
for attribute_name in dir(module):
|
13 |
-
attribute = getattr(module, attribute_name)
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
# import all files under utils/
|
4 |
+
from os.path import dirname, basename, isfile, join
|
5 |
+
import glob
|
6 |
+
modules = glob.glob(join(dirname(__file__), "*.py"))
|
7 |
+
__all__ = [basename(f)[:-3] for f in modules if isfile(f)
|
8 |
+
and not f.endswith('__init__.py')]
|
src/utils/file.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
# enssure folder exists
|
4 |
+
|
5 |
+
|
6 |
+
def ensure_folder_exists(path):
|
7 |
+
if not os.path.exists(path):
|
8 |
+
os.makedirs(path)
|
src/utils/logs.py
CHANGED
@@ -1,2 +1,34 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
log_dir = "logs"
|
5 |
+
log_level = logging.INFO
|
6 |
+
|
7 |
+
if not os.path.exists(log_dir):
|
8 |
+
os.makedirs(log_dir)
|
9 |
+
|
10 |
+
log_filename = os.path.join(log_dir, "app.log")
|
11 |
+
|
12 |
+
logging.basicConfig(
|
13 |
+
filename=log_filename,
|
14 |
+
level=log_level,
|
15 |
+
format="%(asctime)s [%(levelname)s]: %(message)s",
|
16 |
+
datefmt="%Y-%m-%d %H:%M:%S"
|
17 |
+
)
|
18 |
+
|
19 |
+
console_handler = logging.StreamHandler()
|
20 |
+
console_handler.setLevel(log_level)
|
21 |
+
console_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s"))
|
22 |
+
logging.getLogger().addHandler(console_handler)
|
23 |
+
|
24 |
+
def log_info(message):
|
25 |
+
logging.info(message)
|
26 |
+
|
27 |
+
def log_warning(message):
|
28 |
+
logging.warning(message)
|
29 |
+
|
30 |
+
def log_error(message):
|
31 |
+
logging.error(message)
|
32 |
+
|
33 |
+
def log_exception(message):
|
34 |
+
logging.exception(message)
|
src/utils/model.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# save model utils
|
2 |
+
import torch
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def save_model(model: torch.nn.Module, path: str) -> str:
|
8 |
+
parent_folder = os.path.dirname(path)
|
9 |
+
os.makedirs(parent_folder, exist_ok=True)
|
10 |
+
torch.save(model.state_dict(), path)
|
11 |
+
return path
|
12 |
+
|
13 |
+
def load_model(model: torch.nn.Module, path: str) -> torch.nn.Module:
|
14 |
+
model.load_state_dict(torch.load(path))
|
15 |
+
return model
|
src/utils/tensorboard.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from torch.utils.tensorboard import SummaryWriter
|
2 |
+
writer = SummaryWriter()
|
src/utils/wandb.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wandb
|
2 |
+
from .logs import log_info
|
3 |
+
from src.configs.model_config import ModelConfig
|
4 |
+
config = ModelConfig().get_config()
|
5 |
+
|
6 |
+
if config.wandb:
|
7 |
+
project = config.wandb_project
|
8 |
+
entity = config.wandb_entity
|
9 |
+
api_key = config.wandb_api_key
|
10 |
+
wandb.login(key=api_key)
|
11 |
+
wandb.init(project=project, entity=entity)
|
12 |
+
log_info("Wandb is enabled")
|
13 |
+
else:
|
14 |
+
log_info("Wandb is disabled")
|
15 |
+
# disable wandb
|
16 |
+
wandb.init(mode="disabled")
|
17 |
+
pass
|