✨ [New] Use the Lightning framework for training!

- yolo/__init__.py +10 -7
- yolo/lazy.py +21 -26
- yolo/tools/solver.py +77 -255
- yolo/utils/bounding_box_utils.py +7 -0
- yolo/utils/logging_utils.py +207 -197
- yolo/utils/model_utils.py +2 -16
- yolo/utils/solver_utils.py +3 -2
yolo/__init__.py
CHANGED

@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
+from yolo.tools.solver import TrainModel
 from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
 from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger, custom_logger, validate_log_directory
+from yolo.utils.logging_utils import (
+    ImageLogger,
+    YOLORichModelSummary,
+    YOLORichProgressBar,
+)
 from yolo.utils.model_utils import PostProccess

 __all__ = [
     "create_model",
     "Config",
-    "ProgressLogger",
+    "YOLORichProgressBar",
     "NMSConfig",
-    "custom_logger",
+    "YOLORichModelSummary",
     "validate_log_directory",
     "draw_bboxes",
     "Vec2Box",
@@ -21,10 +25,9 @@ __all__ = [
     "bbox_nms",
     "create_converter",
     "AugmentationComposer",
+    "ImageLogger",
     "create_dataloader",
     "FastModelLoader",
-    "ModelTester",
-    "ModelTrainer",
-    "ModelValidator",
+    "TrainModel",
     "PostProccess",
 ]
yolo/lazy.py
CHANGED

@@ -2,41 +2,36 @@ import sys
 from pathlib import Path

 import hydra
+from lightning import Trainer

 project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))

 from yolo.config.config import Config
-from yolo.model.yolo import create_model
-from yolo.tools.data_loader import create_dataloader
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import create_converter
-from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger
-from yolo.utils.model_utils import get_device
+from yolo.tools.solver import TrainModel, ValidateModel
+from yolo.utils.logging_utils import setup


 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    ...
-    solver.solve(dataloader)
+    callbacks, loggers = setup(cfg)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        max_epochs=getattr(cfg.task, "epoch", None),
+        precision="16-mixed",
+        callbacks=callbacks,
+        logger=loggers,
+        log_every_n_steps=1,
+    )
+
+    match cfg.task.task:
+        case "train":
+            model = TrainModel(cfg)
+            trainer.fit(model)
+        case "validation":
+            model = ValidateModel(cfg)
+            trainer.validate(model)


 if __name__ == "__main__":
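Note that the `match` statement in the new entry point requires Python 3.10 or later. On older interpreters the same dispatch can be written with plain if/elif; a minimal equivalent sketch (not part of the commit), reusing the classes this diff introduces:

# Equivalent task dispatch without `match`, for Python < 3.10.
# `cfg` and `trainer` are assumed to be built exactly as in lazy.py above.
from yolo.tools.solver import TrainModel, ValidateModel

def dispatch(cfg, trainer):
    task = cfg.task.task
    if task == "train":
        trainer.fit(TrainModel(cfg))
    elif task == "validation":
        trainer.validate(ValidateModel(cfg))
    else:
        raise ValueError(f"unknown task: {task}")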
yolo/tools/solver.py
CHANGED

@@ -1,267 +1,89 @@
-import contextlib
-import io
-import json
-import os
-import time
-from collections import defaultdict
-from pathlib import Path
-from typing import Dict, Optional
-
-import torch
-from pycocotools.coco import COCO
-from torch import Tensor, distributed
-from torch.cuda.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
-
-from yolo.config.config import Config, DatasetConfig, TrainConfig, ValidationConfig
-from yolo.model.yolo import YOLO
-from yolo.tools.data_loader import StreamDataLoader, create_dataloader
-from yolo.tools.drawer import draw_bboxes, draw_model
+from lightning import LightningModule
+from torchmetrics.detection import MeanAveragePrecision
+
+from yolo.config.config import Config
+from yolo.model.yolo import create_model
+from yolo.tools.data_loader import create_dataloader
 from yolo.tools.loss_functions import create_loss_function
-from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
-from yolo.utils.dataset_utils import locate_label_paths
-from yolo.utils.logger import logger
-from yolo.utils.logging_utils import ProgressLogger, log_model_structure
-from yolo.utils.model_utils import (
-    ExponentialMovingAverage,
-    PostProccess,
-    collect_prediction,
-    create_optimizer,
-    create_scheduler,
-    predicts_to_json,
-)
-from yolo.utils.solver_utils import calculate_ap
+from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
+from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler


-class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
-        train_cfg: TrainConfig = cfg.task
-        self.model = model if not use_ddp else DDP(model, device_ids=[device])
-        self.use_ddp = use_ddp
-        self.vec2box = vec2box
-        self.device = device
-        self.optimizer = create_optimizer(model, train_cfg.optimizer)
-        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
-        self.loss_fn = create_loss_function(cfg, vec2box)
-        self.progress = progress
-        self.num_epochs = cfg.task.epoch
-        self.mAPs_dict = defaultdict(list)
-
-        ...
-        draw_model(model=model)
-
-        self.validation_dataloader = create_dataloader(
-            cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
-        )
-        self.validator = ModelValidator(cfg.task.validation, cfg.dataset, model, vec2box, progress, device)
-
-        if ...:
-            ...
-        else:
-            self.ema = None
-        self.scaler = GradScaler()
-
-    def train_one_batch(self, images: Tensor, targets: Tensor):
-        images, targets = images.to(self.device), targets.to(self.device)
-        self.optimizer.zero_grad()
-
-        with autocast():
-            predicts = self.model(images)
-            aux_predicts = self.vec2box(predicts["AUX"])
-            main_predicts = self.vec2box(predicts["Main"])
-            loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
-
-        self.scaler.scale(loss).backward()
-        self.scaler.unscale_(self.optimizer)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-
-        return loss_item
-
-    def train_one_epoch(self, dataloader):
-        self.model.train()
-        total_loss = defaultdict(float)
-        total_samples = 0
-        self.optimizer.next_epoch(len(dataloader))
-        for batch_size, images, targets, *_ in dataloader:
-            self.optimizer.next_batch()
-            loss_each = self.train_one_batch(images, targets)
-
-            for loss_name, loss_val in loss_each.items():
-                if self.use_ddp:  # collecting loss for each batch
-                    distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
-                total_loss[loss_name] += loss_val.item() * batch_size
-            total_samples += batch_size
-            self.progress.one_batch(loss_each)
-
-        for loss_val in total_loss.values():
-            loss_val /= total_samples
-
-        if self.scheduler:
-            self.scheduler.step()
-
-        return total_loss
-
-    def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
-        file_name = file_name or f"E{epoch_idx:03d}.pt"
-        file_path = self.weights_dir / file_name
-
-        checkpoint = {
-            "epoch": epoch_idx,
-            "model_state_dict": self.model.state_dict(),
-            "optimizer_state_dict": self.optimizer.state_dict(),
-        }
-        if self.ema:
-            self.ema.apply_shadow()
-            checkpoint["model_state_dict_ema"] = self.model.state_dict()
-            self.ema.restore()
-
-        logger.info(f"💾 success save at {file_path}")
-        torch.save(checkpoint, file_path)
-
-    def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
-        save_flag = True
-        for mAP_key, mAP_val in mAPs.items():
-            self.mAPs_dict[mAP_key].append(mAP_val)
-            if mAP_val < max(self.mAPs_dict[mAP_key]):
-                save_flag = False
-        return save_flag
-
-    def solve(self, dataloader: DataLoader):
-        logger.info("🚄 Start Training!")
-        num_epochs = self.num_epochs
-
-        self.progress.start_train(num_epochs)
-        for epoch_idx in range(num_epochs):
-            if self.use_ddp:
-                dataloader.sampler.set_epoch(epoch_idx)
-
-            ...
-            if mAPs is not None and self.good_epoch(mAPs):
-                self.save_checkpoint(epoch_idx=epoch_idx)
-            # TODO: save model if result are better than before
-        self.progress.finish_train()
-
-
-class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, cfg.task.nms)
-        self.save_path = progress.save_path / "images"
-        os.makedirs(self.save_path, exist_ok=True)
-        self.save_predict = getattr(cfg.task, "save_predict", None)
-        self.idx2label = cfg.dataset.class_list
-
-    def solve(self, dataloader: StreamDataLoader):
-        logger.info("👀 Start Inference!")
-        if isinstance(self.model, torch.nn.Module):
-            self.model.eval()
-
-        if dataloader.is_stream:
-            import cv2
-            import numpy as np
-
-        last_time = time.time()
-        try:
-            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
-                images = images.to(self.device)
-                rev_tensor = rev_tensor.to(self.device)
-                with torch.no_grad():
-                    predicts = self.model(images)
-                    predicts = self.post_proccess(predicts, rev_tensor)
-                img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
-
-                if dataloader.is_stream:
-                    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-                    fps = 1 / (time.time() - last_time)
-                    cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
-                    last_time = time.time()
-                    cv2.imshow("Prediction", img)
-                    if cv2.waitKey(1) & 0xFF == ord("q"):
-                        break
-                if not self.save_predict:
-                    continue
-                if self.save_predict != False:
-                    save_image_path = self.save_path / f"frame{idx:03d}.png"
-                    img.save(save_image_path)
-                    logger.info(f"💾 Saved visualize image at {save_image_path}")
-
-        except (KeyboardInterrupt, Exception) as e:
-            dataloader.stop_event.set()
-            dataloader.stop()
-            if isinstance(e, KeyboardInterrupt):
-                logger.error("User Keyboard Interrupt")
-            else:
-                raise e
-        dataloader.stop()
-
-
-class ModelValidator:
-    def __init__(
-        self,
-        validation_cfg: ValidationConfig,
-        dataset_cfg: DatasetConfig,
-        model: YOLO,
-        vec2box: Vec2Box,
-        progress: ProgressLogger,
-        device,
-    ):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
-        self.json_path = self.progress.save_path / "predict.json"
-
-        with contextlib.redirect_stdout(io.StringIO()):
-            # TODO: load with config file
-            json_path, _ = locate_label_paths(Path(dataset_cfg.path), dataset_cfg.get("validation", "val"))
-            if json_path:
-                self.coco_gt = COCO(json_path)
-
-    def solve(self, dataloader, epoch_idx=1):
-        # logger.info("🧪 Start Validation!")
-        self.model.eval()
-        predict_json, mAPs = [], defaultdict(list)
-        self.progress.start_one_epoch(len(dataloader), task="Validate")
-        for batch_size, images, targets, rev_tensor, img_paths in dataloader:
-            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
-            with torch.no_grad():
-                predicts = self.model(images)
-                predicts = self.post_proccess(predicts)
-                for idx, predict in enumerate(predicts):
-                    mAP = calculate_map(predict, targets[idx])
-                    for mAP_key, mAP_val in mAP.items():
-                        mAPs[mAP_key].append(mAP_val)
-
-            avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
-            self.progress.one_batch(avg_mAPs)
-
-            ...
-        self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
-
-        ...
-        self.progress.start_pycocotools()
-        result = calculate_ap(self.coco_gt, predict_json)
-        self.progress.finish_pycocotools(result, epoch_idx)
-
-        ...
+class BaseModel(LightningModule):
+    def __init__(self, cfg: Config):
+        super().__init__()
+        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ValidateModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        if self.cfg.task.task == "validation":
+            self.validation_cfg = self.cfg.task
+        else:
+            self.validation_cfg = self.cfg.task.validation
+        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
+
+    def val_dataloader(self):
+        return create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+
+    def validation_step(self, batch, batch_idx):
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        predicts = self.post_proccess(self(images))
+        batch_metrics = self.metric(
+            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
+        )
+
+        self.log_dict(
+            {
+                "map": batch_metrics["map"],
+                "map_50": batch_metrics["map_50"],
+            },
+            on_step=True,
+            prog_bar=True,
+            logger=False,
+            batch_size=batch_size,
+        )
+        return predicts
+
+    def on_validation_epoch_end(self):
+        epoch_metrics = self.metric.compute()
+        del epoch_metrics["classes"]
+        self.log_dict(epoch_metrics, on_epoch=True, prog_bar=True, logger=True)
+
+
+class TrainModel(ValidateModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.loss_fn = create_loss_function(self.cfg, self.vec2box)
+
+    def train_dataloader(self):
+        return create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
+
+    def training_step(self, batch, batch_idx):
+        batch_size, images, targets, *_ = batch
+        predicts = self(images)
+        aux_predicts = self.vec2box(predicts["AUX"])
+        main_predicts = self.vec2box(predicts["Main"])
+        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
+        self.log_dict(loss_item, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        return loss * batch_size
+
+    def configure_optimizers(self):
+        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
+        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
+        return [optimizer], [scheduler]
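The hand-written train/validate loops above are replaced by Lightning's hook protocol. As a mental model only (Lightning drives this internally, and `cfg` is assumed to be the Hydra-built Config), the hooks fire roughly in this order:

# Rough hook order for trainer.fit(TrainModel(cfg)), spelled out as plain
# calls — a sketch of the lifecycle, not how you would actually run it.
from yolo.tools.solver import TrainModel

def walk_hooks(cfg):
    model = TrainModel(cfg)
    model.setup(stage="fit")                 # builds vec2box, PostProccess, loss_fn
    optimizers, schedulers = model.configure_optimizers()
    for batch_idx, batch in enumerate(model.train_dataloader()):
        loss = model.training_step(batch, batch_idx)  # Lightning calls backward()/step()
    for batch_idx, batch in enumerate(model.val_dataloader()):
        model.validation_step(batch, batch_idx)       # per-batch mAP on the prog bar
    model.on_validation_epoch_end()          # epoch-level metrics go to the loggers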
yolo/utils/bounding_box_utils.py
CHANGED

@@ -446,3 +446,10 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05)):
         "mAP.5:.95": torch.mean(torch.stack(aps)),
     }
     return mAP
+
+
+def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
+    bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
+    if prediction.size(1) == 6:
+        bbox["scores"] = prediction[:, 5]
+    return bbox
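`to_metrics_format` bridges the repo's flat tensor layout — (class, x1, y1, x2, y2) for targets, plus a trailing confidence column for predictions — to the dict layout torchmetrics expects. A small illustration (the box coordinates here are made up):

# Feeding flat YOLO-style tensors to torchmetrics via to_metrics_format.
import torch
from torchmetrics.detection import MeanAveragePrecision

from yolo.utils.bounding_box_utils import to_metrics_format

# (class, x1, y1, x2, y2, confidence) — one predicted box of class 0
pred = torch.tensor([[0, 10.0, 10.0, 50.0, 50.0, 0.9]])
# (class, x1, y1, x2, y2) — the ground truth has no confidence column,
# so to_metrics_format omits "scores" for it
gt = torch.tensor([[0, 12.0, 11.0, 49.0, 52.0]])

metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
metric.update([to_metrics_format(pred)], [to_metrics_format(gt)])
print(metric.compute()["map_50"])  # ~1.0: IoU ≈ 0.86 clears the 0.5 threshold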
yolo/utils/logging_utils.py
CHANGED

@@ -11,9 +11,7 @@ Example:
     custom_logger()
 """

-import os
-import random
-import sys
+import logging
 from collections import deque
 from logging import FileHandler
 from pathlib import Path
@@ -22,39 +20,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import wandb
-import wandb.errors.term
+from lightning import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
+from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
+from lightning.pytorch.loggers import WandbLogger
 from omegaconf import ListConfig
+from rich import reconfigure
 from rich.console import Console, Group
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeRemainingColumn,
-)
+from rich.logging import RichHandler
 from rich.table import Table
+from rich.text import Text
 from torch import Tensor
 from torch.nn import ModuleList
-from torchvision.transforms.functional import pil_to_tensor
+from typing_extensions import override

 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
-from yolo.tools.drawer import draw_bboxes
 from yolo.utils.logger import logger
 from yolo.utils.solver_utils import make_ap_table


-def custom_logger(quite: bool = False):
-    if quite:
-        logger.removeHandler("YOLO_logger")
-
-
 # TODO: should be moved to correct position
 def set_seed(seed):
-    ...
-    np.random.seed(seed)
-    torch.manual_seed(seed)
+    seed_everything(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
@@ -62,189 +50,211 @@ def set_seed(seed):
     torch.backends.cudnn.benchmark = False


-class ProgressLogger(Progress):
-    def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
-        ...
-        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
-
-        progress_bar = (
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(bar_width=None),
-            TextColumn("{task.completed:.0f}/{task.total:.0f}"),
-            TimeRemainingColumn(),
-        )
-        self.ap_table = Table()
-        # TODO: load maxlen by config files
-        self.ap_past_list = deque(maxlen=5)
-        self.last_result = 0
-        super().__init__(*args, *progress_bar, **kwargs)
-
-        self.use_wandb = cfg.use_wandb
-        if self.use_wandb and self.local_rank == 0:
-            wandb.errors.term._log = custom_wandb_log
-            self.wandb = wandb.init(
-                project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
-            )
-
-        self.use_tensorboard = cfg.use_tensorboard
-        if self.use_tensorboard and self.local_rank == 0:
-            from torch.utils.tensorboard import SummaryWriter
-
-            ...
-
-    def rank_check(logging_function):
-        def wrapper(self, *args, **kwargs):
-            if getattr(self, "local_rank", 0) != 0:
-                return
-            return logging_function(self, *args, **kwargs)
-
-        return wrapper
-
-    @rank_check
-    def start_train(self, num_epochs):
-        ...
-
-    @rank_check
-    def start_one_epoch(self, num_batches, task="Train", ...):
-        ...
-
-    @rank_check
-    def one_batch(self, batch_info: Dict[str, Tensor] = None):
-        epoch_descript = "[cyan]" + self.task + "[white] |"
-        batch_descript = "|"
-        if self.task == "Train":
-            self.update(self.task_epoch, advance=1 / self.num_batches)
-        for info_name, info_val in batch_info.items():
-            epoch_descript += f"{info_name: ^9}|"
-            batch_descript += f" {info_val:2.2f} |"
-        self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
-        if hasattr(self, "task_epoch"):
-            self.update(self.task_epoch, description=epoch_descript)
-
-    @rank_check
-    def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
-        if self.task == "Train":
-            prefix = "Loss"
-        elif self.task == "Validate":
-            prefix = "Metrics"
-        batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
-        if self.use_wandb:
-            self.wandb.log(batch_info, step=epoch_idx)
-        if self.use_tensorboard:
-            for key, value in batch_info.items():
-                self.tb_writer.add_scalar(key, value, epoch_idx)
-
-        self.remove_task(self.batch_task)
-
-    @rank_check
-    def visualize_image(
-        self,
-        images: Optional[Tensor] = None,
-        ground_truth: Optional[Tensor] = None,
-        prediction: Optional[Union[List[Tensor], Tensor]] = None,
-        epoch_idx: int = 0,
-    ) -> None:
-        """
-        Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
-
-        Args:
-            images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
-            ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
-            prediction (Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6). Defaults to None.
-            epoch_idx (int): Current epoch index. Defaults to 0.
-        """
-        if images is not None:
-            images = images[0] if images.ndim == 4 else images
-            if self.use_wandb:
-                wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Input Image", images, 1)
-
-        if ground_truth is not None:
-            gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
-            if self.use_wandb:
-                wandb.log(
-                    {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
-
-        if prediction is not None:
-            pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
-            if self.use_wandb:
-                wandb.log(
-                    {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
-
-    @rank_check
-    def start_pycocotools(self):
-        self.batch_task = self.add_task("[green]Run pycocotools", total=1)
-
-    @rank_check
-    def finish_pycocotools(self, result, epoch_idx=-1):
-        ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
-        self.last_result = np.maximum(result, self.last_result)
-        self.ap_past_list.append((epoch_idx, ap_main))
-        self.ap_table = ap_table
-
-        if self.use_wandb:
-            self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
-        if self.use_tensorboard:
-            # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
-
-        self.update(self.batch_task, advance=1)
-        self.refresh()
-        self.remove_task(self.batch_task)
-
-    @rank_check
-    def finish_train(self):
-        ...
-
-
-def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
-    ...
+class YOLOCustomProgress(CustomProgress):
+    def get_renderable(self):
+        renderable = Group(*self.get_renderables())
+        if hasattr(self, "table"):
+            renderable = Group(*self.get_renderables(), self.table)
+        return renderable
+
+
+class YOLORichProgressBar(RichProgressBar):
+    @override
+    def _init_progress(self, trainer: "Trainer") -> None:
+        if self.is_enabled and (self.progress is None or self._progress_stopped):
+            self._reset_progress_bar_ids()
+            reconfigure(**self._console_kwargs)
+            self._console = Console()
+            self._console.clear_live()
+            self.progress = YOLOCustomProgress(
+                *self.configure_columns(trainer),
+                auto_refresh=False,
+                disable=self.is_disabled,
+                console=self._console,
+            )
+            self.progress.start()
+
+            self._progress_stopped = False
+
+        self.max_result = 0
+        self.past_results = deque(maxlen=5)
+        self.progress.table = Table()
+
+    @override
+    def _get_train_description(self, current_epoch: int) -> str:
+        return Text("[cyan]Train [white]|")
+
+    @override
+    def on_train_start(self, trainer, pl_module):
+        self._init_progress(trainer)
+        num_epochs = trainer.max_epochs - 1
+        self.task_epoch = self._add_task(
+            total_batches=num_epochs,
+            description=f"[cyan]Start Training {num_epochs} epochs",
+        )
+        self.max_result = 0
+        self.past_results.clear()
+        self.progress.update(self.task_epoch, advance=-0.5)
+
+    @override
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
+        self._update(self.train_progress_bar_id, batch_idx + 1)
+        self._update_metrics(trainer, pl_module)
+        epoch_descript = "[cyan]Train [white]|"
+        batch_descript = "[green]Train [white]|"
+        metrics = self.get_metrics(trainer, pl_module)
+        metrics.pop("v_num")
+        for metrics_name, metrics_val in metrics.items():
+            if "Loss_step" in metrics_name:
+                epoch_descript += f"{metrics_name.removesuffix('_step'): ^9}|"
+                batch_descript += f" {metrics_val:2.2f} |"
+
+        self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
+        self.progress.update(self.train_progress_bar_id, description=batch_descript)
+        self.refresh()
+
+    @override
+    def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        self._update_metrics(trainer, pl_module)
+        self.progress.remove_task(self.train_progress_bar_id)
+        self.train_progress_bar_id = None
+
+    @override
+    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        if trainer.state.fn == "fit":
+            self._update_metrics(trainer, pl_module)
+        self.reset_dataloader_idx_tracker()
+        all_metrics = self.get_metrics(trainer, pl_module)

+        ap_ar_list = [
+            key
+            for key in all_metrics.keys()
+            if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
+        ]
+        score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
+
+        self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
+        self.max_result = np.maximum(score, self.max_result)
+        self.past_results.append((trainer.current_epoch, ap_main))
+
+    @override
+    def refresh(self) -> None:
+        if self.progress:
+            self.progress.refresh()
+
+    @property
+    def validation_description(self) -> str:
+        return "[green]Validation"
+
+
+class YOLORichModelSummary(RichModelSummary):
+    @staticmethod
+    @override
+    def summarize(
+        summary_data: List[Tuple[str, List[str]]],
+        total_parameters: int,
+        trainable_parameters: int,
+        model_size: float,
+        total_training_modes: Dict[str, int],
+        **summarize_kwargs: Any,
+    ) -> None:
+        from lightning.pytorch.utilities.model_summary import get_human_readable_count
+        from rich import get_console
+        from rich.table import Table
+
+        console = get_console()
+
+        header_style: str = summarize_kwargs.get("header_style", "bold magenta")
+        table = Table(header_style=header_style)
+        table.add_column(" ", style="dim")
+        table.add_column("Name", justify="left", no_wrap=True)
+        table.add_column("Type")
+        table.add_column("Params", justify="right")
+        table.add_column("Mode")
+
+        column_names = list(zip(*summary_data))[0]
+
+        for column_name in ["In sizes", "Out sizes"]:
+            if column_name in column_names:
+                table.add_column(column_name, justify="right", style="white")
+
+        rows = list(zip(*(arr[1] for arr in summary_data)))
+        for row in rows:
+            table.add_row(*row)
+
+        console.print(table)
+
+        parameters = []
+        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
+            parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
+
+        grid = Table(header_style=header_style)
+        grid.add_column("[bold]Attributes[/]")
+        grid.add_column("Value")
+
+        grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
+        grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
+        grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
+        grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
+        grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
+        grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
+
+        console.print(grid)
+
+
+class ImageLogger(Callback):
+    def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
+        if batch_idx != 0:
+            return
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        gt_boxes = targets[0] if targets.ndim == 3 else targets
+        pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
+        images = [images[0]]
+        step = trainer.current_epoch
+        for logger in trainer.loggers:
+            if isinstance(logger, WandbLogger):
+                logger.log_image("Input Image", images, step=step)
+                logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
+                logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
+
+
+def setup(cfg: Config):
+    if hasattr(cfg, "quite"):
+        logger.removeHandler("YOLO_logger")
+        return
+
+    class EmojiFormatter(logging.Formatter):
+        def format(self, record):
+            return f":high_voltage: {super().format(record)}"
+
+    rich_handler = RichHandler(markup=True)
+    rich_handler.setFormatter(EmojiFormatter("%(message)s"))
+    lightning_logger = logging.getLogger("lightning.pytorch")
+    lightning_logger.handlers.clear()
+    lightning_logger.addHandler(rich_handler)
+
+    def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
+        if silent:
+            return
+        for line in string.split("\n"):
+            logger.info(Text.from_ansi(":globe_with_meridians: " + line))
+
+    wandb.errors.term._log = custom_wandb_log
+
+    save_path = validate_log_directory(cfg, cfg.name)
+
+    progress, loggers = [], []
+    progress.append(YOLORichProgressBar())
+    progress.append(YOLORichModelSummary())
+    progress.append(ImageLogger())
+
+    loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
+
+    return progress, loggers


 def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -291,7 +301,7 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
     )

     save_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"📄 Created log folder: [blue b u]{save_path}[/]")
     logger.addHandler(FileHandler(save_path / "output.log"))
     return save_path

@@ -327,4 +337,4 @@ def log_bbox(
     bbox_entry["scores"] = {"confidence": conf[0]}
     bbox_list.append(bbox_entry)

-    return bbox_list
+    return {"predictions": {"box_data": bbox_list}}
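The `setup()` function routes Lightning's own log messages through a rich handler with an emoji prefix. That trick works standalone, too — the snippet below is the same pattern lifted out of the diff, runnable on its own:

# Standalone demo of the log-routing used in setup(): Lightning's messages
# are re-emitted through a RichHandler; ":high_voltage:" renders as ⚡
# because markup=True enables rich's emoji/markup processing.
import logging

from rich.logging import RichHandler

class EmojiFormatter(logging.Formatter):
    def format(self, record):
        return f":high_voltage: {super().format(record)}"

rich_handler = RichHandler(markup=True)
rich_handler.setFormatter(EmojiFormatter("%(message)s"))

lightning_logger = logging.getLogger("lightning.pytorch")
lightning_logger.handlers.clear()        # drop Lightning's default handler
lightning_logger.addHandler(rich_handler)
lightning_logger.warning("checkpoint dir already exists")  # prints: ⚡ checkpoint dir ...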
yolo/utils/model_utils.py
CHANGED

@@ -56,23 +56,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         {"params": conv_params},
         {"params": norm_params, "weight_decay": 0},
     ]
-
-    def next_epoch(self, batch_num):
-        self.min_lr = self.max_lr
-        self.max_lr = [param["lr"] for param in self.param_groups]
-        self.batch_num = batch_num
-        self.batch_idx = 0
-
-    def next_batch(self):
-        self.batch_idx += 1
-        for lr_idx, param_group in enumerate(self.param_groups):
-            min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
-            param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
-
-    optimizer_class.next_batch = next_batch
-    optimizer_class.next_epoch = next_epoch
     optimizer = optimizer_class(model_parameters, **optim_cfg.args)
-
+    # TODO: implement a batch-level LR scheduler for warm-up
    return optimizer

@@ -168,6 +153,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
     batch_json = []
     for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
         scale, shift = box_reverse.split([1, 4])
+        bboxes = bboxes.clone()
         bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
         bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
         for cls, *pos, conf in bboxes:
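The added `bboxes.clone()` presumably keeps `predicts_to_json` from rescaling the prediction tensor in place, since the same tensor is consumed by other code. The commit also drops the monkey-patched `next_epoch`/`next_batch` warm-up hooks and leaves a TODO; one conventional replacement (not part of this commit, just a sketch of what the TODO could look like) is a per-batch `LambdaLR`:

# Hypothetical warm-up replacement for the removed next_epoch/next_batch
# hooks — NOT in the commit, only one standard way to satisfy the TODO.
import torch
from torch.optim.lr_scheduler import LambdaLR

def build_warmup(optimizer: torch.optim.Optimizer, warmup_steps: int) -> LambdaLR:
    # Linearly scale the LR from ~0 up to its configured value over
    # warmup_steps, stepping once per batch instead of once per epoch.
    def scale(step: int) -> float:
        return min(1.0, (step + 1) / warmup_steps)
    return LambdaLR(optimizer, lr_lambda=scale)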
yolo/utils/solver_utils.py
CHANGED

@@ -1,5 +1,6 @@
 import contextlib
 import io
+from typing import Dict

 import numpy as np
 from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
     return coco_eval.stats


-def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
+def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
     ap_table = Table()
     ap_table.add_column("Epoch", justify="center", style="white", width=5)
     ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
     if past_result:
         ap_table.add_row()

-    color = np.where(last_score <= score, "[green]", "[red]")
+    color = np.where(max_result <= score, "[green]", "[red]")

     this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
     metrics = [
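
The renamed `max_result` makes the coloring rule read correctly: it is an elementwise `np.where` over the score array, so each metric that matches or beats the running best renders green and each regression renders red. A tiny illustration (the numbers are made up):

# Per-metric coloring as done in make_ap_table.
import numpy as np

score = np.array([42.1, 61.3, 38.0])        # e.g. AP .5:.95, AP .5, AR
max_result = np.array([40.0, 62.0, 38.0])   # best values seen so far
color = np.where(max_result <= score, "[green]", "[red]")
print(color)  # ['[green]' '[red]' '[green]']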
|