"""
D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
"""
import datetime
import json
import time

import torch

from ..misc import dist_utils, stats
from ._solver import BaseSolver
from .det_engine import evaluate, train_one_epoch
class DetSolver(BaseSolver):
    """Solver that drives training (`fit`) and stand-alone evaluation (`val`).

    Everything stateful (model, EMA, optimizer, dataloaders, evaluator, writer,
    output directory, ...) is read from attributes on ``self``; they are not set
    in this class and are presumably prepared by ``BaseSolver`` -- confirm there.
    """

    def _eval_module(self):
        """Return the module to evaluate: the EMA copy if EMA is enabled, else the model."""
        return self.ema.module if self.ema else self.model

    def _save_best_checkpoint(self, epoch):
        """Save the solver state as the best checkpoint of the stage `epoch` belongs to.

        Epochs at or after ``collate_fn.stop_epoch`` write ``best_stg2.pth``;
        earlier epochs write ``best_stg1.pth``. Callers must ensure
        ``self.output_dir`` is set.
        """
        if epoch >= self.train_dataloader.collate_fn.stop_epoch:
            filename = "best_stg2.pth"
        else:
            filename = "best_stg1.pth"
        dist_utils.save_on_master(self.state_dict(), self.output_dir / filename)

    def fit(self) -> None:
        """Run the training loop from ``self.last_epoch + 1`` up to ``cfg.epochs``.

        Per epoch this method: trains one epoch, steps the LR scheduler once
        warm-up is finished, saves ``last.pth`` (plus periodic checkpoints)
        while ``epoch < collate_fn.stop_epoch``, evaluates the EMA/model, tracks
        the best first metric value per stats key, saves ``best_stg1.pth`` /
        ``best_stg2.pth``, and appends a JSON line to ``log.txt``.

        At ``epoch == collate_fn.stop_epoch`` the state saved in
        ``best_stg1.pth`` is reloaded and the EMA decay is set to
        ``collate_fn.ema_restart_decay``. From that epoch on, an epoch that does
        not set a new best resets the per-stage best, lowers the EMA decay by
        ``0.0001`` and reloads ``best_stg1.pth`` again (only when EMA is on).
        """
        self.train()
        args = self.cfg
        # Labels applied, in order, to the leading entries of
        # test_stats["coco_eval_bbox"] when logging to wandb.
        metric_names = ["AP50:95", "AP50", "AP75", "APsmall", "APmedium", "APlarge"]

        if self.use_wandb:
            # Imported lazily so wandb is only needed when wandb logging is on.
            import wandb

            wandb.init(
                project=args.yaml_cfg["project_name"],
                name=args.yaml_cfg["exp_name"],
                config=args.yaml_cfg,
            )
            wandb.watch(self.model)

        n_parameters, model_stats = stats(self.cfg)
        print(model_stats)
        print("-" * 42 + "Start training" + "-" * 43)

        # top1: best first-metric value seen so far (across both stages).
        # best_stat: best value per stats key for the current stage; it is reset
        # in the second stage whenever an epoch fails to improve.
        top1 = 0
        best_stat = {
            "epoch": -1,
        }
        if self.last_epoch > 0:
            # Resuming: evaluate the loaded weights first so the "best" trackers
            # start from the resumed model's score rather than from zero.
            test_stats, coco_evaluator = evaluate(
                self._eval_module(),
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                self.last_epoch,
                self.use_wandb,
            )
            for k in test_stats:
                best_stat["epoch"] = self.last_epoch
                best_stat[k] = test_stats[k][0]
                top1 = test_stats[k][0]
                print(f"best_stat: {best_stat}")

        # Separate copy used only for printing the global best.
        best_stat_print = best_stat.copy()

        start_time = time.time()
        start_epoch = self.last_epoch + 1
        for epoch in range(start_epoch, args.epochs):
            self.train_dataloader.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            if epoch == self.train_dataloader.collate_fn.stop_epoch:
                # Entering the second stage: restart from the best first-stage state.
                self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                if self.ema:
                    self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay
                    print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
                use_wandb=self.use_wandb,
                output_dir=self.output_dir,
            )

            # The epoch-level scheduler only advances once warm-up is over (or absent).
            if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                self.lr_scheduler.step()

            self.last_epoch += 1

            if self.output_dir and epoch < self.train_dataloader.collate_fn.stop_epoch:
                checkpoint_paths = [self.output_dir / "last.pth"]
                # Extra numbered checkpoint every `checkpoint_freq` epochs.
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(self.output_dir / f"checkpoint{epoch:04}.pth")
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(), checkpoint_path)

            test_stats, coco_evaluator = evaluate(
                self._eval_module(),
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                epoch,
                self.use_wandb,
                output_dir=self.output_dir,
            )

            # Track the best score per stats key; index 0 of each stats list is
            # the value used for model selection.
            for k in test_stats:
                if self.writer and dist_utils.is_main_process():
                    for i, v in enumerate(test_stats[k]):
                        self.writer.add_scalar(f"Test/{k}_{i}", v, epoch)

                if k in best_stat:
                    best_stat["epoch"] = (
                        epoch if test_stats[k][0] > best_stat[k] else best_stat["epoch"]
                    )
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat["epoch"] = epoch
                    best_stat[k] = test_stats[k][0]

                if best_stat[k] > top1:
                    # New global best.
                    best_stat_print["epoch"] = epoch
                    top1 = best_stat[k]
                    if self.output_dir:
                        self._save_best_checkpoint(epoch)

                best_stat_print[k] = max(best_stat[k], top1)
                print(f"best_stat: {best_stat_print}")  # global best

                if best_stat["epoch"] == epoch and self.output_dir:
                    # This epoch is the best of the current stage.
                    if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                        if test_stats[k][0] > top1:
                            top1 = test_stats[k][0]
                            self._save_best_checkpoint(epoch)
                    else:
                        top1 = max(test_stats[k][0], top1)
                        self._save_best_checkpoint(epoch)

                elif epoch >= self.train_dataloader.collate_fn.stop_epoch:
                    # Second stage without improvement: forget the stage best,
                    # lower the EMA decay slightly and restart from best_stg1.pth.
                    best_stat = {
                        "epoch": -1,
                    }
                    if self.ema:
                        self.ema.decay -= 0.0001
                        self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                        print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
                "n_parameters": n_parameters,
            }

            if self.use_wandb:
                wandb_logs = {}
                # Skip the metric entries (but still log the epoch) when the
                # evaluator produced no bbox stats, instead of raising KeyError.
                bbox_stats = test_stats.get("coco_eval_bbox", ())
                for metric_name, value in zip(metric_names, bbox_stats):
                    wandb_logs[f"metrics/{metric_name}"] = value
                wandb_logs["epoch"] = epoch
                wandb.log(wandb_logs)

            if self.output_dir and dist_utils.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # Dump the raw bbox evaluation: always as latest.pth, and as a
                # numbered snapshot every 50 epochs.
                if coco_evaluator is not None:
                    (self.output_dir / "eval").mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ["latest.pth"]
                        if epoch % 50 == 0:
                            filenames.append(f"{epoch:03}.pth")
                        for name in filenames:
                            torch.save(
                                coco_evaluator.coco_eval["bbox"].eval,
                                self.output_dir / "eval" / name,
                            )

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"Training time {total_time_str}")

    def val(self) -> None:
        """Evaluate the EMA module (or the model) once on the validation set.

        When ``self.output_dir`` is set and the evaluator produced bbox results,
        the raw bbox evaluation is saved to ``eval.pth`` in that directory.
        """
        self.eval()

        _, coco_evaluator = evaluate(
            self._eval_module(),
            self.criterion,
            self.postprocessor,
            self.val_dataloader,
            self.evaluator,
            self.device,
            epoch=-1,
            use_wandb=False,
        )

        # Same guards as the per-epoch dump in fit(): there may be no evaluator,
        # or it may hold no bbox results.
        if (
            self.output_dir
            and coco_evaluator is not None
            and "bbox" in coco_evaluator.coco_eval
        ):
            dist_utils.save_on_master(
                coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth"
            )