D-FINE / src /solver /clas_engine.py
developer0hye's picture
Upload 76 files
e85fecb verified
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import torch
import torch.nn as nn
from ..misc import MetricLogger, SmoothedValue, reduce_dict
def train_one_epoch(
    model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device
):
    """Train ``model`` for a single pass over ``dataloader``.

    For every ``(imgs, labels)`` batch the model is run forward, the loss is
    computed with ``criterion(preds, labels, epoch)``, and one optimizer step
    is taken. The loss and the learning rate of the first parameter group are
    recorded in a ``MetricLogger``.

    Args:
        model: Network to optimize; switched to training mode on entry.
        criterion: Callable returning a scalar loss tensor from
            ``(preds, labels, epoch)``.
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        optimizer: Optimizer whose ``param_groups[0]["lr"]`` is logged.
        ema: Optional object with an ``update(model)`` method, called after
            each optimizer step; pass ``None`` to disable.
        epoch: Epoch index, used in the log header and forwarded to
            ``criterion``.
        device: Target device for the batch tensors.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    model.train()

    tracker = MetricLogger(delimiter=" ")
    tracker.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    banner = "Epoch: [{}]".format(epoch)
    log_interval = 100

    for batch_imgs, batch_labels in tracker.log_every(dataloader, log_interval, banner):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        # Forward pass and loss; the criterion also receives the epoch index.
        outputs = model(batch_imgs)
        loss: torch.Tensor = criterion(outputs, batch_labels, epoch)

        # Standard optimization step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the EMA copy (if any) in step with the freshly updated weights.
        if ema is not None:
            ema.update(model)

        # NOTE(review): reduce_dict presumably aggregates the loss across
        # distributed workers -- confirm against ..misc.
        reduced = reduce_dict({"loss": loss})
        tracker.update(**{name: value.item() for name, value in reduced.items()})
        tracker.update(lr=optimizer.param_groups[0]["lr"])

    tracker.synchronize_between_processes()
    print("Averaged stats:", tracker)

    return {name: meter.global_avg for name, meter in tracker.meters.items()}
@torch.no_grad()
def evaluate(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    For each ``(imgs, labels)`` batch this computes the top-1 accuracy (the
    fraction of samples whose ``argmax`` over the last dimension equals the
    label) and the loss ``criterion(preds, labels)``, and records both in a
    ``MetricLogger``.

    Args:
        model: Network to evaluate; switched to eval mode on entry.
        criterion: Callable returning a scalar loss tensor from
            ``(preds, labels)``.
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        device: Target device for the batch tensors.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    model.eval()

    tracker = MetricLogger(delimiter=" ")
    for meter_name in ("acc", "loss"):
        tracker.add_meter(meter_name, SmoothedValue(window_size=1))

    banner = "Test:"
    for batch_imgs, batch_labels in tracker.log_every(dataloader, 10, banner):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_imgs)

        # Count of correct top-1 predictions divided by the batch size.
        num_correct = (outputs.argmax(dim=-1) == batch_labels).sum()
        acc = num_correct / outputs.shape[0]
        loss = criterion(outputs, batch_labels)

        # NOTE(review): reduce_dict presumably aggregates the metrics across
        # distributed workers -- confirm against ..misc.
        reduced = reduce_dict({"acc": acc, "loss": loss})
        tracker.update(**{name: value.item() for name, value in reduced.items()})

    tracker.synchronize_between_processes()
    print("Averaged stats:", tracker)

    return {name: meter.global_avg for name, meter in tracker.meters.items()}