"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import torch
import torch.nn as nn

from ..misc import MetricLogger, SmoothedValue, reduce_dict


def train_one_epoch(
    model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device
):
    """Run one training epoch; returns the loss/lr statistics averaged across processes."""
    model.train()

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    print_freq = 100
    header = "Epoch: [{}]".format(epoch)

    for imgs, labels in metric_logger.log_every(dataloader, print_freq, header):
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = model(imgs)
        loss: torch.Tensor = criterion(preds, labels, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the exponential-moving-average (EMA) weights in sync with the live model.
        if ema is not None:
            ema.update(model)

        # Average the loss across distributed processes before logging.
        loss_reduced_values = {k: v.item() for k, v in reduce_dict({"loss": loss}).items()}
        metric_logger.update(**loss_reduced_values)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats


@torch.no_grad()
def evaluate(model, criterion, dataloader, device):
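    """Run evaluation, logging top-1 accuracy and loss averaged across processes."""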
    model.eval()

    metric_logger = MetricLogger(delimiter="  ")
    # metric_logger.add_meter('acc', SmoothedValue(window_size=1, fmt='{global_avg:.4f}'))
    # metric_logger.add_meter('loss', SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter("acc", SmoothedValue(window_size=1))
    metric_logger.add_meter("loss", SmoothedValue(window_size=1))

    header = "Test:"
    for imgs, labels in metric_logger.log_every(dataloader, 10, header):
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)

        # Top-1 accuracy: fraction of predictions whose argmax matches the label.
        acc = (preds.argmax(dim=-1) == labels).sum() / preds.shape[0]
        loss = criterion(preds, labels)

        dict_reduced = reduce_dict({"acc": acc, "loss": loss})
        reduced_values = {k: v.item() for k, v in dict_reduced.items()}
        metric_logger.update(**reduced_values)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats
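

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# `build_model`, `build_criterion`, and `build_dataloaders` are hypothetical
# stand-ins for whatever the surrounding training script actually provides.
# ---------------------------------------------------------------------------
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = build_model().to(device)
# criterion = build_criterion()
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# train_loader, val_loader = build_dataloaders()
#
# for epoch in range(num_epochs):
#     train_stats = train_one_epoch(
#         model, criterion, train_loader, optimizer, ema=None, epoch=epoch, device=device
#     )
#     eval_stats = evaluate(model, criterion, val_loader, device)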