# navsim_ours/det_map/train_det.py
# (Hugging Face file-page header, kept as comments so the module parses as Python.)
# lkllkl's picture
# Upload folder using huggingface_hub
# da2e2ac verified
from typing import Tuple
import hydra
from hydra.utils import instantiate
import logging
from omegaconf import DictConfig
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from det_map.data.datasets.dataset_det import DetDataset
from det_map.utils import collate_fn_pad_lidar
from det_map.data.datasets.dataset import Dataset
from navsim.planning.training.agent_lightning_module import AgentLightningModule
from det_map.data.datasets.dataloader import SceneLoader
from det_map.data.datasets.dataclasses import SceneFilter
from navsim.agents.abstract_agent import AbstractAgent
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Values handed to the @hydra.main decorator on main() below:
# CONFIG_PATH is passed as config_path and CONFIG_NAME as config_name.
CONFIG_PATH = "config/"
CONFIG_NAME = "train_det"
def build_datasets(cfg: DictConfig, agent: AbstractAgent) -> Tuple[Dataset, Dataset]:
    """
    Create the training and validation datasets for detection training.

    For each split a separate scene filter is instantiated from
    ``cfg.scene_filter`` and its ``log_names`` is overwritten with the split's
    log list (``cfg.train_logs`` / ``cfg.val_logs``). Each filter feeds a
    ``SceneLoader`` reading from ``cfg.navsim_log_path`` and
    ``cfg.sensor_blobs_path``, and each loader is wrapped in a ``DetDataset``
    configured with the agent's feature builders, target builders and
    pipelines.

    :param cfg: Hydra/OmegaConf configuration; the keys read here are
        ``scene_filter``, ``train_logs``, ``val_logs``, ``navsim_log_path``
        and ``sensor_blobs_path``.
    :param agent: agent supplying the sensor config, feature/target builders
        and ``pipelines``.
    :return: tuple ``(train_data, val_data)``.
    """
    # Build one independent filter per split so that overriding log_names on
    # one of them cannot leak into the other.
    filters_by_split = {}
    for split, logs_key in (("train", "train_logs"), ("val", "val_logs")):
        split_filter: SceneFilter = instantiate(cfg.scene_filter)
        split_filter.log_names = getattr(cfg, logs_key)
        filters_by_split[split] = split_filter

    log_root = Path(cfg.navsim_log_path)
    blob_root = Path(cfg.sensor_blobs_path)

    # Dicts keep insertion order, so loaders (and then datasets) are created
    # train-first, val-second.
    loaders_by_split = {
        split: SceneLoader(
            sensor_blobs_path=blob_root,
            data_path=log_root,
            scene_filter=split_filter,
            sensor_config=agent.get_sensor_config(),
        )
        for split, split_filter in filters_by_split.items()
    }

    datasets_by_split = {
        split: DetDataset(
            scene_loader=split_loader,
            feature_builders=agent.get_feature_builders(),
            target_builders=agent.get_target_builders(),
            pipelines=agent.pipelines,
            # Only the training split is flagged as training data.
            is_train=(split == "train"),
        )
        for split, split_loader in loaders_by_split.items()
    }

    return datasets_by_split["train"], datasets_by_split["val"]
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """
    Entry point for detection training.

    Seeds everything with 0, instantiates the agent from ``cfg.agent``, wraps
    it in an ``AgentLightningModule``, builds the train/val datasets and their
    ``DataLoader`` objects, then creates a ``pl.Trainer`` from
    ``cfg.trainer.params`` and runs ``fit``.

    :param cfg: configuration supplied by the ``hydra.main`` decorator.
    """

    def _make_loader(dataset, shuffle):
        # Both splits share cfg.dataloader.params and the lidar-padding
        # collate function; only the shuffle flag differs.
        return DataLoader(
            dataset,
            **cfg.dataloader.params,
            shuffle=shuffle,
            collate_fn=collate_fn_pad_lidar,
        )

    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info(f"Path where all results are stored: {cfg.output_dir}")

    logger.info("Building Agent")
    agent: AbstractAgent = instantiate(cfg.agent)

    logger.info("Building Lightning Module")
    module = AgentLightningModule(agent=agent)

    logger.info("Building SceneLoader")
    train_set, val_set = build_datasets(cfg, agent)

    logger.info("Building Datasets")
    train_loader = _make_loader(train_set, True)
    logger.info("Num training samples: %d", len(train_set))
    val_loader = _make_loader(val_set, False)
    logger.info("Num validation samples: %d", len(val_set))

    logger.info("Building Trainer")
    trainer = pl.Trainer(**cfg.trainer.params)

    logger.info("Starting Training")
    trainer.fit(
        model=module,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()