# Source path: D-FINE/src/solver/det_solver.py
# Repository page residue (not part of the module): "developer0hye's picture",
# "Upload 76 files", "e85fecb verified".
"""
D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
"""
import datetime
import json
import time
import torch
from ..misc import dist_utils, stats
from ._solver import BaseSolver
from .det_engine import evaluate, train_one_epoch
class DetSolver(BaseSolver):
    """Solver that drives training (`fit`) and standalone evaluation (`val`).

    Both methods evaluate the EMA copy of the model when `self.ema` is set and
    the raw model otherwise.
    """

    def fit(self):
        """Train from `self.last_epoch + 1` up to `cfg.epochs`, evaluating every epoch.

        Training is split in two stages by `train_dataloader.collate_fn.stop_epoch`:

        * Before `stop_epoch`: `last.pth` (plus a periodic `checkpointNNNN.pth`)
          is written every epoch and the best model is saved as `best_stg1.pth`.
        * At `stop_epoch`: `best_stg1.pth` is reloaded and, when EMA is enabled,
          its decay is set to `collate_fn.ema_restart_decay`.
        * From `stop_epoch` on: the best model is saved as `best_stg2.pth`. An
          epoch that does not set a new best resets the running best, and when
          EMA is enabled lowers its decay by 0.0001 and reloads `best_stg1.pth`.

        Per-epoch stats are appended to `log.txt` and the COCO bbox evaluation
        state is stored under `eval/` in `self.output_dir` (main process only).
        """
        self.train()
        args = self.cfg
        # Order matches the indices read from test_stats["coco_eval_bbox"] below.
        metric_names = ["AP50:95", "AP50", "AP75", "APsmall", "APmedium", "APlarge"]

        if self.use_wandb:
            # Imported lazily so wandb is only required when it is enabled.
            import wandb

            wandb.init(
                project=args.yaml_cfg["project_name"],
                name=args.yaml_cfg["exp_name"],
                config=args.yaml_cfg,
            )
            wandb.watch(self.model)

        n_parameters, model_stats = stats(self.cfg)
        print(model_stats)
        print("-" * 42 + "Start training" + "-" * 43)

        # top1: best first-entry metric value seen so far.
        # best_stat: running best per metric key (reset during stage 2, see below).
        top1 = 0
        best_stat = {
            "epoch": -1,
        }

        if self.last_epoch > 0:
            # Resuming: evaluate once so the best-so-far values start from the
            # resumed model instead of zero.
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                self.last_epoch,
                self.use_wandb,
            )
            for k in test_stats:
                best_stat["epoch"] = self.last_epoch
                best_stat[k] = test_stats[k][0]
                top1 = test_stats[k][0]
                print(f"best_stat: {best_stat}")

        # Separate copy used only for printing the global best; unlike
        # best_stat it is never reset.
        best_stat_print = best_stat.copy()

        start_time = time.time()
        start_epoch = self.last_epoch + 1
        for epoch in range(start_epoch, args.epochs):
            self.train_dataloader.set_epoch(epoch)
            # self.train_dataloader.dataset.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            if epoch == self.train_dataloader.collate_fn.stop_epoch:
                # Stage switch: restart from the best stage-1 checkpoint.
                self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                if self.ema:
                    self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay
                    print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
                use_wandb=self.use_wandb,
                output_dir=self.output_dir,
            )

            # The epoch-level scheduler only advances once warmup is done
            # (or when there is no warmup scheduler at all).
            if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                self.lr_scheduler.step()

            self.last_epoch += 1

            # Rolling/periodic checkpoints are only written during stage 1.
            if self.output_dir and epoch < self.train_dataloader.collate_fn.stop_epoch:
                checkpoint_paths = [self.output_dir / "last.pth"]
                # extra checkpoint before LR drop and every 100 epochs
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(self.output_dir / f"checkpoint{epoch:04}.pth")
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(), checkpoint_path)

            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                epoch,
                self.use_wandb,
                output_dir=self.output_dir,
            )

            # TODO
            for k in test_stats:
                if self.writer and dist_utils.is_main_process():
                    for i, v in enumerate(test_stats[k]):
                        self.writer.add_scalar(f"Test/{k}_{i}", v, epoch)

                # Update the running best; best_stat["epoch"] becomes the
                # current epoch only on a strict improvement or a new key.
                if k in best_stat:
                    best_stat["epoch"] = (
                        epoch if test_stats[k][0] > best_stat[k] else best_stat["epoch"]
                    )
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat["epoch"] = epoch
                    best_stat[k] = test_stats[k][0]

                # New global best: remember it and save to the current stage's file.
                if best_stat[k] > top1:
                    best_stat_print["epoch"] = epoch
                    top1 = best_stat[k]
                    if self.output_dir:
                        if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg2.pth"
                            )
                        else:
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg1.pth"
                            )

                best_stat_print[k] = max(best_stat[k], top1)
                print(f"best_stat: {best_stat_print}")  # global best

                if best_stat["epoch"] == epoch and self.output_dir:
                    if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                        if test_stats[k][0] > top1:
                            top1 = test_stats[k][0]
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg2.pth"
                            )
                    else:
                        top1 = max(test_stats[k][0], top1)
                        dist_utils.save_on_master(
                            self.state_dict(), self.output_dir / "best_stg1.pth"
                        )

                elif epoch >= self.train_dataloader.collate_fn.stop_epoch:
                    # Stage 2 epoch without a new running best: forget the
                    # running best and, with EMA enabled, lower the decay
                    # slightly and restart from the best stage-1 weights.
                    best_stat = {
                        "epoch": -1,
                    }
                    if self.ema:
                        self.ema.decay -= 0.0001
                        self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                        print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
                "n_parameters": n_parameters,
            }

            if self.use_wandb:
                wandb_logs = {}
                for idx, metric_name in enumerate(metric_names):
                    wandb_logs[f"metrics/{metric_name}"] = test_stats["coco_eval_bbox"][idx]
                wandb_logs["epoch"] = epoch
                wandb.log(wandb_logs)

            if self.output_dir and dist_utils.is_main_process():
                # One JSON object per line, appended each epoch.
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / "eval").mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        # "latest.pth" is overwritten every epoch; a numbered
                        # snapshot is kept every 50 epochs.
                        filenames = ["latest.pth"]
                        if epoch % 50 == 0:
                            filenames.append(f"{epoch:03}.pth")
                        for name in filenames:
                            torch.save(
                                coco_evaluator.coco_eval["bbox"].eval,
                                self.output_dir / "eval" / name,
                            )

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))

    def val(self):
        """Run one evaluation pass and save the COCO bbox eval state.

        The state is written to `eval.pth` in `self.output_dir` when an output
        directory is set and a bbox evaluator result is available.
        """
        self.eval()

        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(
            module,
            self.criterion,
            self.postprocessor,
            self.val_dataloader,
            self.evaluator,
            self.device,
            epoch=-1,
            use_wandb=False,
        )

        # Same guards as the evaluation-log branch in fit(): the evaluator may
        # be None or may not carry a "bbox" entry.
        if (
            self.output_dir
            and coco_evaluator is not None
            and "bbox" in coco_evaluator.coco_eval
        ):
            dist_utils.save_on_master(
                coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth"
            )