File size: 3,372 Bytes
fd01725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import time
import torch
import hydra
import pytorch_lightning as pl
from typing import Any

from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from pathlib import Path
from dataclasses import dataclass

from .module import GenericModule
from .data.module import GenericDataModule
from .callbacks import EvalSaveCallback, ImageLoggerCallback
from .models.schema import ModelConfiguration, DINOConfiguration, ResNetConfiguration
from .data.schema import MIADataConfiguration, KITTIDataConfiguration, NuScenesDataConfiguration


@dataclass
class ExperimentConfiguration:
    """Structured-config schema for the ``experiment`` node.

    Attributes:
        name: Experiment identifier. ``train`` uses it as the prefix of the
            timestamped W&B run name and run id.
    """

    name: str

@dataclass
class Configuration:
    """Root structured-config schema composed by Hydra for ``train``.

    Attributes:
        model: Model settings, typed by ``ModelConfiguration``.
        experiment: Experiment metadata, typed by ``ExperimentConfiguration``.
        data: Dataset settings. Typed ``Any`` here; dataset-specific schemas
            are registered separately in the ``schema/data`` config group.
        training: Run / trainer settings. Typed ``Any``, so this dataclass
            imposes no structure on it; ``train`` reads keys such as
            ``eval``, ``checkpoint`` and ``trainer`` from it.
    """

    # Field order is the positional order of the generated __init__.
    model: ModelConfiguration
    experiment: ExperimentConfiguration
    data: Any
    training: Any


cs = ConfigStore.instance()

# Root schema: register ``Configuration`` under the name of each primary
# config (the one selected via ``config_name`` in ``hydra.main``).
for _root_name in ("pretrain", "mapper_nuscenes", "mapper_kitti"):
    cs.store(name=_root_name, node=Configuration)

# Dataset-specific schemas in the ``schema/data`` group, placed under the
# ``data`` package when selected.
for _data_name, _data_schema in (
    ("mia", MIADataConfiguration),
    ("kitti", KITTIDataConfiguration),
    ("nuscenes", NuScenesDataConfiguration),
):
    cs.store(group="schema/data", name=_data_name, node=_data_schema, package="data")

# Image-encoder backbone schemas, placed under
# ``model.image_encoder.backbone`` when selected.
for _backbone_name, _backbone_schema in (
    ("dino", DINOConfiguration),
    ("resnet", ResNetConfiguration),
):
    cs.store(
        group="model/schema/backbone",
        name=_backbone_name,
        node=_backbone_schema,
        package="model.image_encoder.backbone",
    )

# Keep the loop variables out of the module namespace.
del _root_name, _data_name, _data_schema, _backbone_name, _backbone_schema


def _load_pretrained_weights(model: torch.nn.Module, checkpoint_path: str) -> None:
    """Load the ``'state_dict'`` entry of a checkpoint file into ``model``.

    Loading is non-strict: keys present on only one side are tolerated, but
    they are now logged instead of being dropped silently, so a checkpoint
    that does not match the model is visible in the run output.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path to a file saved with ``torch.save`` containing a
            ``'state_dict'`` key.
    """
    import logging

    log = logging.getLogger(__name__)

    # map_location="cpu": a checkpoint written on a GPU can be read on a host
    # without that device; load_state_dict then copies each tensor onto the
    # device of the corresponding model parameter.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the
    # checkpoint also stores non-tensor objects (e.g. hyper-parameters) this
    # call may need weights_only=False for trusted files — confirm against the
    # installed torch version.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)

    if incompatible.missing_keys:
        log.warning(
            "Checkpoint %s is missing %d model keys (left at their current values): %s",
            checkpoint_path,
            len(incompatible.missing_keys),
            incompatible.missing_keys,
        )
    if incompatible.unexpected_keys:
        log.warning(
            "Checkpoint %s has %d keys not present in the model (ignored): %s",
            checkpoint_path,
            len(incompatible.unexpected_keys),
            incompatible.unexpected_keys,
        )


@hydra.main(version_base=None, config_path="conf", config_name="pretrain")
def train(cfg: Configuration):
    """Train or evaluate a ``GenericModule`` from a Hydra-composed config.

    Interpolations in ``cfg`` are resolved in place first. Then:

    * if ``cfg.training.eval`` is truthy, ``trainer.test`` is run with an
      ``EvalSaveCallback`` writing into ``cfg.training.save_dir`` (created if
      needed) and ``logger=None`` passed to the Trainer;
    * otherwise ``trainer.fit`` is run with an ``ImageLoggerCallback``, a
      ``ModelCheckpoint`` configured from ``cfg.training.checkpointing``, and
      a ``WandbLogger`` whose name and id are
      ``<cfg.experiment.name>_<YYYY-mm-dd_HH-MM-SS>``.

    If ``cfg.training.checkpoint`` is not None, its ``state_dict`` is loaded
    into the model (non-strict) before the Trainer is built. All other
    ``pl.Trainer`` keyword arguments come from ``cfg.training.trainer``.
    """
    OmegaConf.resolve(cfg)

    dm = GenericDataModule(cfg.data)
    model = GenericModule(cfg)

    exp_name_with_time = cfg.experiment.name + "_" + time.strftime("%Y-%m-%d_%H-%M-%S")

    callbacks: list[pl.Callback]

    if cfg.training.eval:
        save_dir = Path(cfg.training.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        callbacks = [EvalSaveCallback(save_dir=save_dir)]

        # NOTE(review): how Trainer interprets logger=None (default logger vs.
        # no logger) depends on the Lightning version — confirm if it matters.
        logger = None
    else:
        checkpointing = cfg.training.checkpointing
        callbacks = [
            ImageLoggerCallback(num_classes=cfg.training.num_classes),
            ModelCheckpoint(
                monitor=checkpointing.monitor,
                save_last=checkpointing.save_last,
                save_top_k=checkpointing.save_top_k,
            ),
        ]

        # Using the timestamped name as the run id gives every launch its
        # own W&B run.
        logger = WandbLogger(
            name=exp_name_with_time,
            id=exp_name_with_time,
            entity="mappred-large",
            project="map-pred-full-v3",
        )

        # log="all": have W&B track both gradients and parameters, every 500 steps.
        logger.watch(model, log="all", log_freq=500)

    if cfg.training.checkpoint is not None:
        _load_pretrained_weights(model, cfg.training.checkpoint)

    trainer_args = OmegaConf.to_container(cfg.training.trainer)
    trainer_args["callbacks"] = callbacks
    trainer_args["logger"] = logger

    trainer = pl.Trainer(**trainer_args)

    if cfg.training.eval:
        trainer.test(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    # Process-wide settings applied before Hydra parses the command line and
    # invokes ``train``:
    # - allow reduced-internal-precision float32 matmuls (e.g. TF32) where
    #   the hardware supports them;
    # - seed all RNGs with a fixed value for repeatable runs.
    torch.set_float32_matmul_precision("high")
    pl.seed_everything(42)

    train()