|
from typing import Tuple
|
|
import hydra
|
|
from hydra.utils import instantiate
|
|
import logging
|
|
from omegaconf import DictConfig
|
|
from pathlib import Path
|
|
import pytorch_lightning as pl
|
|
from torch.utils.data import DataLoader
|
|
|
|
from det_map.data.datasets.dataset_det import DetDataset
|
|
from det_map.utils import collate_fn_pad_lidar
|
|
from det_map.data.datasets.dataset import Dataset
|
|
from navsim.planning.training.agent_lightning_module import AgentLightningModule
|
|
from det_map.data.datasets.dataloader import SceneLoader
|
|
from det_map.data.datasets.dataclasses import SceneFilter
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
|
|
# Logger for this script, named after the module.
logger = logging.getLogger(__name__)

# Passed to ``hydra.main`` as ``config_path`` / ``config_name`` to locate the training config.
CONFIG_PATH, CONFIG_NAME = "config/", "train_det"
|
|
|
|
def _build_split_dataset(
    cfg: DictConfig,
    agent: AbstractAgent,
    log_names,
    data_path: Path,
    sensor_blobs_path: Path,
    *,
    is_train: bool,
) -> DetDataset:
    """Build one dataset split whose scenes are restricted to ``log_names``.

    :param cfg: Hydra config; ``cfg.scene_filter`` is instantiated as the split's filter.
    :param agent: agent supplying the sensor config, feature/target builders and pipelines.
    :param log_names: log names for this split, as read from the config
        (e.g. ``cfg.train_logs``); assigned verbatim to ``SceneFilter.log_names``.
    :param data_path: directory passed to ``SceneLoader`` as ``data_path``.
    :param sensor_blobs_path: directory passed to ``SceneLoader`` as ``sensor_blobs_path``.
    :param is_train: forwarded to ``DetDataset`` to select train or eval behaviour.
    :return: the constructed ``DetDataset``.
    """
    # A fresh filter per split, so setting log_names on one split cannot leak into the other.
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names

    scene_loader = SceneLoader(
        sensor_blobs_path=sensor_blobs_path,
        data_path=data_path,
        scene_filter=scene_filter,
        sensor_config=agent.get_sensor_config(),
    )

    return DetDataset(
        scene_loader=scene_loader,
        feature_builders=agent.get_feature_builders(),
        target_builders=agent.get_target_builders(),
        pipelines=agent.pipelines,
        is_train=is_train,
    )


def build_datasets(cfg: DictConfig, agent: AbstractAgent) -> Tuple[Dataset, Dataset]:
    """Build the training and validation detection datasets.

    Both splits are built identically. They differ only in the log names
    they are filtered to (``cfg.train_logs`` vs. ``cfg.val_logs``) and in
    the ``is_train`` flag passed to ``DetDataset``.

    :param cfg: Hydra config. Reads ``scene_filter``, ``train_logs``, ``val_logs``,
        ``navsim_log_path`` and ``sensor_blobs_path``.
    :param agent: agent providing the sensor config, feature/target builders and pipelines.
    :return: ``(train_data, val_data)``.
    """
    data_path = Path(cfg.navsim_log_path)
    sensor_blobs_path = Path(cfg.sensor_blobs_path)

    train_data = _build_split_dataset(
        cfg, agent, cfg.train_logs, data_path, sensor_blobs_path, is_train=True
    )
    val_data = _build_split_dataset(
        cfg, agent, cfg.val_logs, data_path, sensor_blobs_path, is_train=False
    )
    return train_data, val_data
|
|
|
|
|
|
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Train a detection agent with PyTorch Lightning.

    Steps:

    1. Seed all RNGs with 0.
    2. Instantiate the agent and wrap it in an ``AgentLightningModule``.
    3. Build the train/val datasets and their dataloaders.
    4. Run ``Trainer.fit``.

    :param cfg: Hydra config loaded from ``CONFIG_PATH``/``CONFIG_NAME``.
        Reads ``output_dir``, ``agent``, ``dataloader.params`` and
        ``trainer.params``, plus the keys read by ``build_datasets``.
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info("Path where all results are stored: %s", cfg.output_dir)

    logger.info("Building Agent")
    agent: AbstractAgent = instantiate(cfg.agent)

    logger.info("Building Lightning Module")
    lightning_module = AgentLightningModule(
        agent=agent,
    )

    # build_datasets constructs both the SceneLoaders and the Datasets.
    logger.info("Building SceneLoaders and Datasets")
    train_data, val_data = build_datasets(cfg, agent)

    # NOTE: shuffle and collate_fn are set here, so cfg.dataloader.params must not
    # also define them (DataLoader would get duplicate keyword arguments).
    logger.info("Building DataLoaders")
    train_dataloader = DataLoader(
        train_data, **cfg.dataloader.params, shuffle=True, collate_fn=collate_fn_pad_lidar
    )
    logger.info("Num training samples: %d", len(train_data))
    val_dataloader = DataLoader(
        val_data, **cfg.dataloader.params, shuffle=False, collate_fn=collate_fn_pad_lidar
    )
    logger.info("Num validation samples: %d", len(val_data))

    logger.info("Building Trainer")
    trainer = pl.Trainer(**cfg.trainer.params)

    logger.info("Starting Training")
    trainer.fit(
        model=lightning_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() is called without arguments because the
    # @hydra.main decorator composes the config and passes it in as ``cfg``.
    main()
|
|
|