"""
An example config file to train an ImageNet classifier with detectron2.
Model and dataloader both come from torchvision.
This shows how to use detectron2 as a general engine for any new models and tasks.

To run, use the following command:

python tools/lazyconfig_train_net.py --config-file configs/Misc/torchvision_imagenet_R_50.py \
    --num-gpus 8 dataloader.train.dataset.root=/path/to/imagenet/

"""


import torch
from torch import nn
from torch.nn import functional as F
from omegaconf import OmegaConf
import torchvision
from torchvision.transforms import transforms as T
from torchvision.models.resnet import ResNet, Bottleneck
from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2.solver import WarmupParamScheduler
from detectron2.solver.build import get_default_optimizer_params
from detectron2.config import LazyCall as L
from detectron2.model_zoo import get_config
from detectron2.data.samplers import TrainingSampler, InferenceSampler
from detectron2.evaluation import DatasetEvaluator
from detectron2.utils import comm


"""
Note: Here we put reusable code (models, evaluation, data) together with configs just as a
proof-of-concept, to easily demonstrate what's needed to train an ImageNet classifier in detectron2.
Writing code in configs offers extreme flexibility but is often not a good engineering practice.
In practice, you might want to put this code in your project and import it instead.
"""


def build_data_loader(dataset, batch_size, num_workers, training=True):
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` driven by a detectron2 sampler.

    Args:
        dataset: the dataset to load from; it must support ``len()``, which sizes the sampler.
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        training: when True, sample with ``TrainingSampler``; otherwise use ``InferenceSampler``.

    Returns:
        torch.utils.data.DataLoader: the loader, with ``pin_memory`` enabled.
    """
    sampler_cls = TrainingSampler if training else InferenceSampler
    index_sampler = sampler_cls(len(dataset))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=index_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )


class ClassificationNet(nn.Module):
    """
    Adapt a plain classifier to the ``(image, label)`` batches produced by the data loader.

    In training mode the forward pass returns the cross-entropy loss against ``label``.
    In eval mode it returns the wrapped model's raw predictions.

    Args:
        model: the classifier to wrap (e.g. a torchvision ResNet).
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model, taken from its first parameter."""
        # ``parameters()`` is a generator: take its first item instead of materializing
        # a list of every parameter on each forward pass.
        return next(self.model.parameters()).device

    def forward(self, inputs):
        """
        Args:
            inputs: an ``(image, label)`` pair of tensors.

        Returns:
            The scalar cross-entropy loss in training mode, otherwise the model predictions.
        """
        image, label = inputs
        pred = self.model(image.to(self.device))
        if not self.training:
            # The label is unused at inference time.
            return pred
        return F.cross_entropy(pred, label.to(self.device))


class ClassificationAcc(DatasetEvaluator):
    """
    Top-1 accuracy evaluator.

    ``process`` counts, per batch, how many argmax predictions match the labels.
    ``evaluate`` sums those counts gathered from all processes and returns the ratio.
    """

    def reset(self):
        """Clear the running counts of correct and total samples."""
        self.corr = self.total = 0

    def process(self, inputs, outputs):
        """
        Accumulate counts for one batch.

        Args:
            inputs: the ``(image, label)`` batch that was given to the model.
            outputs: the model's predictions for that batch; the argmax is taken over dim 1.
        """
        _, label = inputs
        pred = outputs.argmax(dim=1).cpu()
        self.corr += (pred == label.cpu()).sum().item()
        self.total += len(label)

    def evaluate(self):
        """
        Returns:
            dict: ``{"accuracy": correct / total}`` over the counts gathered from all
            processes, or ``{"accuracy": nan}`` when no samples were processed at all.
        """
        # Each gathered entry is one process's [corr, total] pair.
        all_corr_total = comm.all_gather([self.corr, self.total])
        corr = sum(x[0] for x in all_corr_total)
        total = sum(x[1] for x in all_corr_total)
        if total == 0:
            # Nothing was evaluated (e.g. an empty split): report NaN instead of
            # crashing with ZeroDivisionError, keeping the result key stable.
            return {"accuracy": float("nan")}
        return {"accuracy": corr / total}


# --- End of code that could be in a project and be imported


# Data loaders are declared with ``L(...)`` (LazyCall) rather than built directly, so their
# fields can be overridden from the command line (see the
# ``dataloader.train.dataset.root=...`` override in the run command at the top of the file).
dataloader = OmegaConf.create()
dataloader.train = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        root="/path/to/imagenet",  # placeholder path; override it with the real dataset root
        split="train",
        # Training augmentation: random resized crop to 224 and random horizontal flip,
        # then tensor conversion and per-channel normalization.
        transform=L(T.Compose)(
            transforms=[
                L(T.RandomResizedCrop)(size=224),
                L(T.RandomHorizontalFlip)(),
                # Constructed directly (not via ``L``): it takes no arguments, so an
                # already-built object is placed in the config.
                T.ToTensor(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    # Presumably the per-GPU share of a total batch of 256 over 8 GPUs (the run command
    # uses ``--num-gpus 8`` and ``train.max_iter`` divides by 256).
    batch_size=256 // 8,
    num_workers=4,
    training=True,
)

# Evaluation loader over the ImageNet validation split, with deterministic preprocessing.
dataloader.test = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        # OmegaConf relative interpolation: ``.`` is this dataset node, ``..`` is
        # ``dataloader.test`` and ``...`` is ``dataloader``, so this resolves to
        # ``dataloader.train.dataset.root``. Overriding the train root moves this one too.
        root="${...train.dataset.root}",
        split="val",
        # Resize to 256, center-crop 224, then the same tensor conversion and
        # normalization as the training loader.
        transform=L(T.Compose)(
            transforms=[
                L(T.Resize)(size=256),
                L(T.CenterCrop)(size=224),
                T.ToTensor(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    batch_size=256 // 8,
    num_workers=4,
    training=False,  # selects InferenceSampler in build_data_loader
)

# Top-1 accuracy over the test loader, aggregated across processes.
dataloader.evaluator = L(ClassificationAcc)()

# ResNet-50 (Bottleneck blocks, layers [3, 4, 6, 3]) wrapped by ClassificationNet.
# The inner ResNet must also be an ``L(...)`` call. Otherwise it is constructed eagerly
# every time this config file is loaded, and the config holds a live nn.Module that can't
# be serialized or overridden (e.g. ``model.model.layers=...``).
model = L(ClassificationNet)(
    model=L(ResNet)(block=Bottleneck, layers=[3, 4, 6, 3], zero_init_residual=True)
)


# SGD with momentum and weight decay.
# ``get_default_optimizer_params`` is called with no arguments here; presumably the training
# script supplies the model before instantiation. TODO: confirm in lazyconfig_train_net.py.
optimizer = L(torch.optim.SGD)(
    params=L(get_default_optimizer_params)(),
    lr=0.1,  # base learning rate; presumably scaled over time by ``lr_multiplier``
    momentum=0.9,
    weight_decay=1e-4,
)

# Step schedule: multiplier 1.0, divided by 10 at each remaining milestone.
# With no ``num_updates`` given, fvcore's MultiStepParamScheduler takes the last milestone
# (100) as the total number of updates. The drops therefore land at 30%, 60% and 90% of
# training, i.e. epochs 30/60/90 of the 100 epochs in ``train.max_iter``.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01, 0.001], milestones=[30, 60, 90, 100]
    ),
    # Warmup over the first 1% of training (about one epoch), starting at 0.1x the
    # scheduled value. The warmup shape is WarmupParamScheduler's default (not set here).
    warmup_length=1 / 100,
    warmup_factor=0.1,
)


# Start from detectron2's shared training defaults, then train from scratch (no checkpoint).
train = get_config("common/train.py").train
train.init_checkpoint = None
# 100 epochs x 1281167 images (the ImageNet-1k train split size) / 256 images per iteration.
train.max_iter = 100 * 1281167 // 256