# Source: D-FINE / src/solver/det_engine.py
# Uploaded by developer0hye ("Upload 76 files", commit e85fecb).
"""
D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from DETR (https://github.com/facebookresearch/detr/blob/main/engine.py)
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
import math
import sys
from typing import Dict, Iterable, List
import numpy as np
import torch
import torch.amp
from torch.cuda.amp.grad_scaler import GradScaler
from torch.utils.tensorboard import SummaryWriter
from ..data import CocoEvaluator
from ..data.dataset import mscoco_category2label
from ..misc import MetricLogger, SmoothedValue, dist_utils, save_samples
from ..optim import ModelEMA, Warmup
from .validator import Validator, scale_boxes
def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Wrappers such as ``DistributedDataParallel`` expose parameters as
    ``module.<name>``; stripping the prefix yields keys that match the unwrapped
    model. Only a *leading* prefix is removed, so names that merely contain the
    prefix (e.g. ``encoder.submodule.w``) are left intact.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _save_nan_snapshot(model: torch.nn.Module, path: str = "./NaN.pth") -> None:
    """Save the model weights (prefix-stripped, under the ``"model"`` key) for debugging.

    Uses ``dist_utils.save_on_master``; a non-finite prediction detected only on a
    non-master rank is therefore not written by that rank.
    """
    dist_utils.save_on_master({"model": _strip_module_prefix(model.state_dict())}, path)


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    use_wandb: bool,
    max_norm: float = 0,
    **kwargs,
):
    """Train ``model`` for one epoch over ``data_loader``.

    Args:
        model: network called as ``model(samples, targets=targets)``.
        criterion: loss module called as ``criterion(outputs, targets, **metas)``;
            must return a dict of loss terms, which are summed into the total loss.
        data_loader: iterable of ``(samples, targets)`` batches; ``targets`` is a
            list of dicts whose tensor values are moved to ``device``.
        optimizer: optimizer stepped once per batch.
        device: device (or device string) the batch is moved to.
        epoch: current epoch index, used for step bookkeeping and logging.
        use_wandb: if true, log lr, epoch and mean train loss to wandb at epoch end.
        max_norm: gradient-clipping threshold; ``<= 0`` disables clipping.
        **kwargs: optional ``print_freq`` (default 10), ``writer`` (TensorBoard
            ``SummaryWriter``), ``ema`` (``ModelEMA``), ``scaler`` (AMP
            ``GradScaler``; when given, the forward pass runs under autocast),
            ``lr_warmup_scheduler`` (``Warmup``), and ``output_dir`` /
            ``num_visualization_sample_batch`` (default 1) for dumping the first
            training batches with ``save_samples``.

    Returns:
        Dict mapping each metric-logger meter name to its global average over the epoch.

    Exits the process with status 1 if the reduced total loss is not finite.
    """
    if use_wandb:
        import wandb

    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    print_freq = kwargs.get("print_freq", 10)
    writer: SummaryWriter = kwargs.get("writer", None)
    ema: ModelEMA = kwargs.get("ema", None)
    scaler: GradScaler = kwargs.get("scaler", None)
    lr_warmup_scheduler: Warmup = kwargs.get("lr_warmup_scheduler", None)
    output_dir = kwargs.get("output_dir", None)
    num_visualization_sample_batch = kwargs.get("num_visualization_sample_batch", 1)

    # autocast expects a bare device type ("cuda", "cpu"); str(torch.device("cuda:0"))
    # is "cuda:0", which some PyTorch versions reject as an autocast device_type.
    device_type = torch.device(device).type
    steps_per_epoch = len(data_loader)
    losses = []  # reduced per-step losses, only needed for the wandb epoch summary

    for i, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        global_step = epoch * steps_per_epoch + i
        metas = dict(epoch=epoch, step=i, global_step=global_step, epoch_step=steps_per_epoch)

        if (
            global_step < num_visualization_sample_batch
            and output_dir is not None
            and dist_utils.is_main_process()
        ):
            save_samples(samples, targets, output_dir, "train", normalized=True, box_fmt="cxcywh")

        samples = samples.to(device)
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        if scaler is not None:
            with torch.autocast(device_type=device_type, cache_enabled=True):
                outputs = model(samples, targets=targets)

            pred_boxes = outputs["pred_boxes"]
            if not torch.isfinite(pred_boxes).all():
                # Keep the weights that produced non-finite boxes for offline debugging.
                print(pred_boxes)
                _save_nan_snapshot(model)

            # The loss is computed in full precision.
            with torch.autocast(device_type=device_type, enabled=False):
                loss_dict = criterion(outputs, targets, **metas)

            loss = sum(loss_dict.values())
            scaler.scale(loss).backward()
            if max_norm > 0:
                # Unscale first so the clipping threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        else:
            outputs = model(samples, targets=targets)
            loss_dict = criterion(outputs, targets, **metas)
            loss: torch.Tensor = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        if ema is not None:
            ema.update(model)
        if lr_warmup_scheduler is not None:
            lr_warmup_scheduler.step()

        loss_dict_reduced = dist_utils.reduce_dict(loss_dict)
        loss_value = sum(loss_dict_reduced.values())
        if use_wandb:
            losses.append(loss_value.detach().cpu().numpy())

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if writer and dist_utils.is_main_process() and global_step % 10 == 0:
            writer.add_scalar("Loss/total", loss_value.item(), global_step)
            for j, pg in enumerate(optimizer.param_groups):
                writer.add_scalar(f"Lr/pg_{j}", pg["lr"], global_step)
            for k, v in loss_dict_reduced.items():
                writer.add_scalar(f"Loss/{k}", v.item(), global_step)

    if use_wandb:
        wandb.log(
            {"lr": optimizer.param_groups[0]["lr"], "epoch": epoch, "train/loss": np.mean(losses)}
        )

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    postprocessor,
    data_loader,
    coco_evaluator: CocoEvaluator,
    device,
    epoch: int,
    use_wandb: bool,
    **kwargs,
):
    """Run inference over ``data_loader`` and compute detection metrics.

    Each batch is passed through ``model`` and ``postprocessor``. The results
    are fed to ``coco_evaluator`` (when one is given) and collected for
    ``Validator``, which computes the confusion matrix, F1, precision, recall
    and box IoU.

    Args:
        model: network called as ``model(samples)``.
        criterion: only switched to eval mode here.
        postprocessor: called as ``postprocessor(outputs, orig_target_sizes)``.
            It must return one dict per image with ``boxes``, ``labels`` and
            ``scores``. It must also expose ``remap_mscoco_category``.
        data_loader: iterable of ``(samples, targets)`` batches; each target
            needs ``image_id``, ``orig_size``, ``boxes`` and ``labels``.
        coco_evaluator: optional COCO evaluator; ``None`` skips COCO metrics.
        device: device the batch is moved to.
        epoch: current epoch, used for sample dumping and wandb logging.
        use_wandb: if true, log the Validator metrics to wandb.
        **kwargs: optional ``output_dir`` / ``num_visualization_sample_batch``
            (default 1) for dumping the first validation batches.

    Returns:
        ``(stats, coco_evaluator)``. ``stats`` holds the COCO summary vectors
        under ``"coco_eval_bbox"`` / ``"coco_eval_masks"``. It is empty when
        ``coco_evaluator`` is ``None``.
    """
    if use_wandb:
        import wandb

    model.eval()
    criterion.eval()

    # coco_evaluator is optional (every later use is guarded), so don't dereference it blindly.
    if coco_evaluator is not None:
        coco_evaluator.cleanup()
        iou_types = coco_evaluator.iou_types
    else:
        iou_types = ()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Test:"

    output_dir = kwargs.get("output_dir", None)
    num_visualization_sample_batch = kwargs.get("num_visualization_sample_batch", 1)
    remap_category = postprocessor.remap_mscoco_category

    gt: List[Dict[str, torch.Tensor]] = []
    preds: List[Dict[str, torch.Tensor]] = []

    steps_per_epoch = len(data_loader)
    for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        global_step = epoch * steps_per_epoch + i
        if (
            global_step < num_visualization_sample_batch
            and output_dir is not None
            and dist_utils.is_main_process()
        ):
            save_samples(samples, targets, output_dir, "val", normalized=False, box_fmt="xyxy")

        samples = samples.to(device)
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        outputs = model(samples)

        # TODO (lyuwenyu), fix dataset converted using `convert_to_coco_api`?
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessor(outputs, orig_target_sizes)

        if coco_evaluator is not None:
            res = {target["image_id"].item(): output for target, output in zip(targets, results)}
            coco_evaluator.update(res)

        # Collect ground truth and predictions in the format Validator expects.
        # All images in the batch share the model input size, so read it once.
        input_size = (samples.shape[-1], samples.shape[-2])
        for target, result in zip(targets, results):
            gt.append(
                {
                    "boxes": scale_boxes(  # from model input size to original img size
                        target["boxes"],
                        (target["orig_size"][1], target["orig_size"][0]),
                        input_size,
                    ),
                    "labels": target["labels"],
                }
            )

            labels = result["labels"]
            if remap_category:
                # Map the postprocessor labels through mscoco_category2label. The explicit
                # dtype keeps an empty detection set integer-typed (torch.tensor([]) is float).
                labels = torch.tensor(
                    [mscoco_category2label[int(c)] for c in labels.flatten().tolist()],
                    dtype=torch.int64,
                    device=labels.device,
                ).reshape(labels.shape)

            preds.append({"boxes": result["boxes"], "labels": labels, "scores": result["scores"]})

    # Conf matrix, F1, Precision, Recall, box IoU.
    # NOTE(review): gt/preds are not gathered across processes; if data_loader is
    # sharded per rank, these metrics cover only the local shard.
    metrics = Validator(gt, preds).compute_metrics()
    print("Metrics:", metrics)
    if use_wandb:
        wandb_metrics = {f"metrics/{k}": v for k, v in metrics.items()}
        wandb_metrics["epoch"] = epoch
        wandb.log(wandb_metrics)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {}
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        if "bbox" in iou_types:
            stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
        if "segm" in iou_types:
            stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()

    return stats, coco_evaluator