File size: 4,497 Bytes
6f6830f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
An example config file to train an ImageNet classifier with detectron2.
Model and dataloader both come from torchvision.
This shows how to use detectron2 as a general engine for any new models and tasks.

To run, use the following command:

python tools/lazyconfig_train_net.py --config-file configs/Misc/torchvision_imagenet_R_50.py \
    --num-gpus 8 dataloader.train.dataset.root=/path/to/imagenet/

"""


import torch
from torch import nn
from torch.nn import functional as F
from omegaconf import OmegaConf
import torchvision
from torchvision.transforms import transforms as T
from torchvision.models.resnet import ResNet, Bottleneck
from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2.solver import WarmupParamScheduler
from detectron2.solver.build import get_default_optimizer_params
from detectron2.config import LazyCall as L
from detectron2.model_zoo import get_config
from detectron2.data.samplers import TrainingSampler, InferenceSampler
from detectron2.evaluation import DatasetEvaluator
from detectron2.utils import comm


"""
Note: Here we put reusable code (models, evaluation, data) together with configs just as a
proof-of-concept, to easily demonstrate what's needed to train an ImageNet classifier in detectron2.
Writing code in configs offers extreme flexibility but is often not a good engineering practice.
In practice, you might want to put code in your project and import them instead.
"""


def build_data_loader(dataset, batch_size, num_workers, training=True):
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` driven by a detectron2 sampler.

    Args:
        dataset: a sized dataset; its length is handed to the sampler.
        batch_size: batch size passed to the ``DataLoader``.
        num_workers: number of loader worker processes passed to the ``DataLoader``.
        training: if true, indices come from a ``TrainingSampler``; otherwise
            from an ``InferenceSampler``.

    Returns:
        torch.utils.data.DataLoader: a loader over ``dataset`` with pinned memory enabled.
    """
    if training:
        sampler_cls = TrainingSampler
    else:
        sampler_cls = InferenceSampler
    index_sampler = sampler_cls(len(dataset))

    loader_kwargs = {
        "sampler": index_sampler,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


class ClassificationNet(nn.Module):
    """
    Thin wrapper that gives a plain classifier the calling convention used by the
    training loop in this config: it consumes an ``(image, label)`` pair, returns
    the cross-entropy loss in training mode and the raw predictions otherwise.

    Args:
        model: the classifier to wrap. It is called with the image batch only and
            must own at least one parameter (used to discover the target device).
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @property
    def device(self):
        """
        Device of the wrapped model's first parameter.

        Raises:
            IndexError: if the wrapped model has no parameters.
        """
        # Only the first parameter is needed, so pull a single item from the
        # generator instead of materializing every parameter on each forward pass.
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            raise IndexError("the wrapped model has no parameters to infer a device from")
        return first_param.device

    def forward(self, inputs):
        """
        Args:
            inputs: an ``(image, label)`` pair of tensors; both are moved to
                ``self.device`` as needed.

        Returns:
            In training mode, the cross-entropy loss between the predictions and
            ``label``; in eval mode, the wrapped model's predictions.
        """
        image, label = inputs
        pred = self.model(image.to(self.device))
        if not self.training:
            return pred
        return F.cross_entropy(pred, label.to(self.device))


class ClassificationAcc(DatasetEvaluator):
    """
    Evaluator that reports top-1 classification accuracy.

    Counts are accumulated locally in ``process`` and combined across all
    processes with ``comm.all_gather`` in ``evaluate``.
    """

    def reset(self):
        """Clear the running counts of correct and seen samples."""
        self.corr = self.total = 0

    def process(self, inputs, outputs):
        """
        Accumulate counts for one batch.

        Args:
            inputs: an ``(image, label)`` pair; only ``label`` is used here.
            outputs: per-class scores, one row per sample; the predicted class is
                the argmax along dim 1.
        """
        _, label = inputs
        # Compare on CPU so predictions and labels may live on different devices.
        self.corr += (outputs.argmax(dim=1).cpu() == label.cpu()).sum().item()
        self.total += len(label)

    def evaluate(self):
        """
        Returns:
            dict: ``{"accuracy": correct / total}`` over the samples seen by all
            processes. If no samples were processed at all, accuracy is NaN
            instead of raising ``ZeroDivisionError``.
        """
        all_corr_total = comm.all_gather([self.corr, self.total])
        corr = sum(x[0] for x in all_corr_total)
        total = sum(x[1] for x in all_corr_total)
        if total == 0:
            # Empty evaluation set: accuracy is undefined, report NaN rather than crash.
            return {"accuracy": float("nan")}
        return {"accuracy": corr / total}


# --- End of code that could be in a project and be imported


# Data loading config: lazily-constructed train/test loaders plus the evaluator.
# Everything wrapped in L(...) is only a description of a call; nothing is built here.
dataloader = OmegaConf.create()
dataloader.train = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        # Placeholder path; override with dataloader.train.dataset.root=... on the command line.
        root="/path/to/imagenet",
        split="train",
        transform=L(T.Compose)(
            transforms=[
                L(T.RandomResizedCrop)(size=224),
                L(T.RandomHorizontalFlip)(),
                # Note: this transform is instantiated eagerly (not wrapped in L),
                # unlike its neighbours.
                T.ToTensor(),
                # NOTE(review): presumably the standard ImageNet per-channel mean/std — confirm.
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    # NOTE(review): presumably a total batch of 256 split over the 8 GPUs used in
    # the example command at the top of this file — adjust if num-gpus changes.
    batch_size=256 // 8,
    num_workers=4,
    training=True,
)

dataloader.test = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        # Relative OmegaConf interpolation pointing at dataloader.train.dataset.root,
        # so overriding the training root also sets the validation root.
        root="${...train.dataset.root}",
        split="val",
        transform=L(T.Compose)(
            transforms=[
                L(T.Resize)(size=256),
                L(T.CenterCrop)(size=224),
                # Eagerly instantiated, same as in the training transforms.
                T.ToTensor(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    batch_size=256 // 8,
    num_workers=4,
    training=False,
)

# Accuracy evaluator defined earlier in this file.
dataloader.evaluator = L(ClassificationAcc)()

# Model config: a torchvision ResNet (Bottleneck blocks, [3, 4, 6, 3] layout)
# wrapped in ClassificationNet. The inner ResNet is described lazily with L(...)
# like every other component, so no network is built when this config is loaded
# and its arguments (e.g. model.model.layers) can be overridden before instantiation.
model = L(ClassificationNet)(
    model=L(ResNet)(block=Bottleneck, layers=[3, 4, 6, 3], zero_init_residual=True)
)


# Optimizer config: SGD with momentum and weight decay, described lazily.
optimizer = L(torch.optim.SGD)(
    # No model is passed to get_default_optimizer_params here.
    # NOTE(review): presumably the training script fills in optimizer.params.model
    # before instantiating — confirm against tools/lazyconfig_train_net.py.
    params=L(get_default_optimizer_params)(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)

# Learning-rate multiplier schedule: a multi-step schedule (1.0 -> 0.1 -> 0.01 -> 0.001)
# wrapped with warmup.
# NOTE(review): the milestones look like epoch counts out of 100 and warmup_length
# like a fraction (1%) of training — confirm against the fvcore MultiStepParamScheduler
# and detectron2 WarmupParamScheduler documentation.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01, 0.001], milestones=[30, 60, 90, 100]
    ),
    warmup_length=1 / 100,
    warmup_factor=0.1,
)


# Training-loop settings: start from the `train` section of the shared
# "common/train.py" config and override what differs for this task.
train = get_config("common/train.py").train
# No initial checkpoint is loaded.
# NOTE(review): presumably so the classifier trains from random initialization — confirm.
train.init_checkpoint = None
# NOTE(review): presumably 100 epochs over 1,281,167 training images at a total
# batch size of 256 (256 // 8 per GPU x 8 GPUs) — confirm the dataset size.
train.max_iter = 100 * 1281167 // 256