Spaces:
Runtime error
Runtime error
""" | |
An example config file to train an ImageNet classifier with detectron2.
Model and dataloader both come from torchvision. | |
This shows how to use detectron2 as a general engine for any new models and tasks. | |
To run, use the following command: | |
python tools/lazyconfig_train_net.py --config-file configs/Misc/torchvision_imagenet_R_50.py \ | |
--num-gpus 8 dataloader.train.dataset.root=/path/to/imagenet/ | |
""" | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from omegaconf import OmegaConf | |
import torchvision | |
from torchvision.transforms import transforms as T | |
from torchvision.models.resnet import ResNet, Bottleneck | |
from fvcore.common.param_scheduler import MultiStepParamScheduler | |
from detectron2.solver import WarmupParamScheduler | |
from detectron2.solver.build import get_default_optimizer_params | |
from detectron2.config import LazyCall as L | |
from detectron2.model_zoo import get_config | |
from detectron2.data.samplers import TrainingSampler, InferenceSampler | |
from detectron2.evaluation import DatasetEvaluator | |
from detectron2.utils import comm | |
""" | |
Note: Here we put reusable code (models, evaluation, data) together with configs just as a
proof-of-concept, to easily demonstrate what's needed to train an ImageNet classifier in detectron2.
Writing code in configs offers extreme flexibility but is often not a good engineering practice.
In practice, you might want to put this code in your project and import it instead.
""" | |
def build_data_loader(dataset, batch_size, num_workers, training=True):
    """Build a ``torch.utils.data.DataLoader`` over ``dataset`` using a detectron2 sampler.

    Args:
        dataset: map-style dataset; its ``len()`` is passed to the sampler.
        batch_size: number of samples per batch produced by this loader.
        num_workers: number of loader worker processes.
        training: use ``TrainingSampler`` when True, ``InferenceSampler`` otherwise.

    Returns:
        The configured DataLoader (always created with ``pin_memory=True``).
    """
    num_samples = len(dataset)
    if training:
        sampler = TrainingSampler(num_samples)
    else:
        sampler = InferenceSampler(num_samples)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        num_workers=num_workers,
    )
class ClassificationNet(nn.Module):
    """Wrap an image classifier so it can be driven by the detectron2 training loop.

    In training mode ``forward`` returns the cross-entropy loss; in eval mode it
    returns the wrapped model's raw predictions.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: classifier mapping an image batch to per-class scores
                (fed to ``F.cross_entropy``, so presumably unnormalized logits).
        """
        super().__init__()
        self.model = model

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model's first parameter.

        Must be a property: ``forward`` reads ``self.device`` as a value and passes
        it to ``Tensor.to``; a plain method would pass the bound method instead.
        """
        # next() avoids materializing the full parameter list on every forward call.
        return next(self.model.parameters()).device

    def forward(self, inputs):
        """
        Args:
            inputs: an ``(image, label)`` pair; tensors are moved to ``self.device`` here.

        Returns:
            Scalar cross-entropy loss when ``self.training``; otherwise the predictions.
        """
        image, label = inputs
        pred = self.model(image.to(self.device))
        if not self.training:
            return pred
        # Labels are only needed for the loss, so they are moved only in training.
        return F.cross_entropy(pred, label.to(self.device))
class ClassificationAcc(DatasetEvaluator):
    """Top-1 accuracy evaluator whose counts are combined across processes via ``comm.all_gather``."""

    def reset(self):
        """Clear the running correct/total counts before a new evaluation pass."""
        self.corr = self.total = 0

    def process(self, inputs, outputs):
        """Accumulate correct and total counts for one batch.

        Args:
            inputs: the ``(image, label)`` batch that was fed to the model.
            outputs: model predictions; the argmax over dim 1 is the predicted class.
        """
        _, label = inputs
        pred_cls = outputs.argmax(dim=1).cpu()
        self.corr += (pred_cls == label.cpu()).sum().item()
        self.total += len(label)

    def evaluate(self):
        """Gather counts from all processes and return overall top-1 accuracy.

        Returns:
            ``{"accuracy": float}``. The value is NaN when no samples were processed
            anywhere, instead of raising ZeroDivisionError.
        """
        all_corr_total = comm.all_gather([self.corr, self.total])
        corr = sum(c for c, _ in all_corr_total)
        total = sum(t for _, t in all_corr_total)
        if total == 0:
            return {"accuracy": float("nan")}
        return {"accuracy": corr / total}
# --- End of code that could be in a project and be imported | |
dataloader = OmegaConf.create()
# Training loader: random-resized crop + horizontal flip, then normalization with
# the usual ImageNet RGB mean/std.
dataloader.train = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        root="/path/to/imagenet",
        split="train",
        transform=L(T.Compose)(
            transforms=[
                L(T.RandomResizedCrop)(size=224),
                L(T.RandomHorizontalFlip)(),
                # Lazy like its siblings: built at instantiate() time, not when this
                # config file is loaded, so the config holds no live objects here.
                L(T.ToTensor)(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    # Per-loader batch; presumably 8 processes x 32 = 256 images in total (see --num-gpus 8).
    batch_size=256 // 8,
    num_workers=4,
    training=True,
)
# Evaluation loader: deterministic resize + center crop on the "val" split.
dataloader.test = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        # Relative interpolation resolving to dataloader.train.dataset.root, so
        # overriding the train root on the command line also moves the test root.
        root="${...train.dataset.root}",
        split="val",
        transform=L(T.Compose)(
            transforms=[
                L(T.Resize)(size=256),
                L(T.CenterCrop)(size=224),
                # Lazy, matching the other transforms (not constructed at config-load time).
                L(T.ToTensor)(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    batch_size=256 // 8,
    num_workers=4,
    training=False,
)
dataloader.evaluator = L(ClassificationAcc)()
# ResNet-50 layout (Bottleneck blocks, [3, 4, 6, 3]) from torchvision, wrapped so it
# returns a loss in training. L(...) keeps the backbone a lazy config node: it is
# built at instantiate() time instead of when this config file is loaded.
model = L(ClassificationNet)(
    model=L(ResNet)(block=Bottleneck, layers=[3, 4, 6, 3], zero_init_residual=True)
)
# SGD with momentum and weight decay (base lr 0.1).
# NOTE(review): get_default_optimizer_params is called with no model here; presumably
# the training script sets optimizer.params.model before instantiation — confirm.
optimizer = L(torch.optim.SGD)(
    params=L(get_default_optimizer_params)(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# Step schedule: the multiplier is divided by 10 at each milestone (1.0 -> 0.001).
# NOTE(review): milestones appear to be on the same 100-unit scale as the 100 epochs
# in train.max_iter, with the last milestone (100) marking the end of training —
# confirm against fvcore MultiStepParamScheduler semantics.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01, 0.001], milestones=[30, 60, 90, 100]
    ),
    # Warm up over the first 1/100 of training (presumably ~1 epoch), starting at 0.1x.
    warmup_length=1 / 100,
    warmup_factor=0.1,
)
# Start from detectron2's common training settings and override what differs here.
train = get_config("common/train.py").train
# No initial checkpoint is loaded (presumably training from scratch).
train.init_checkpoint = None
# 100 epochs x 1281167 images (presumably the ImageNet-1k train split) / 256 images per iteration.
train.max_iter = 100 * 1281167 // 256