---
# Experiment configuration: data loading, trainer and model settings.
# Shared timing / tokenization settings live in the `time_info` anchor and are
# merged with `<<: *time_info` into Dataset, Model and Model.decoder. Keys set
# explicitly in those mappings take precedence over the merged-in values.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: true
  token_size: 2048
  predict_motion: true
  predict_state: true
  predict_map: true
  predict_occ: true
  state_token:
    invalid: 0
    valid: 1
    enter: 2
    exit: 3
  pl2seed_radius: 75.0
  disable_grid_token: true
  grid_range: 150.0  # 2 x pl2seed_radius
  grid_interval: 3.0
  angle_interval: 3.0
  seed_size: 1
  buffer_size: 128
  max_num: 32

Dataset:
  root: null
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: true
  num_workers: 1
  pin_memory: true
  persistent_workers: true
  train_raw_dir: 'data/waymo_processed/training'
  val_raw_dir: 'data/waymo_processed/validation'
  # The test split reuses the validation data.
  test_raw_dir: 'data/waymo_processed/validation'
  # NOTE(review): this path is under training/ while val_raw_dir points at
  # validation/ — confirm that is intended.
  val_tfrecords_splitted: 'data/waymo_processed/training/validation_tfrecords_splitted'
  transform: WaymoTargetBuilder
  train_processed_dir: null
  val_processed_dir: null
  test_processed_dir: null
  dataset: 'scalable'
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: 'gpu'
  devices: 1
  max_epochs: 32
  save_ckpt_path: null
  num_nodes: 1
  mode: null
  ckpt_path: null
  precision: 32
  accumulate_grad_batches: 1
  overfit_epochs: 6000

Model:
  predictor: 'smart'
  decoder_type: 'agent_decoder'  # choose from ['agent_decoder', 'occ_decoder']
  dataset: 'waymo'
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: false
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  predict_map_token: false
  num_recurrent_steps_val: 300
  val_open_loop: false
  val_close_loop: true
  val_insert: false
  n_rollout_close_val: 1
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    a2sa_radius: 10
    pl2sa_radius: 10
    time_span: 60
    # NOTE(review): nesting inferred from key order — confirm the model reads
    # loss_weight from Model.decoder rather than from Model.
    loss_weight:
      token_cls_loss: 1
      map_token_loss: 1
      state_cls_loss: 10
      type_cls_loss: 5
      pos_cls_loss: 1
      head_cls_loss: 1
      offset_reg_loss: 5
      shape_reg_loss: 0.2
      pos_reg_loss: 10
      state_weight: [0.1, 0.1, 0.8]  # invalid, valid, exit
      seed_state_weight: [0.1, 0.9]  # invalid, enter
      seed_type_weight: [0.8, 0.1, 0.1]
      agent_occ_pos_weight: 100
      pt_occ_pos_weight: 5
      agent_occ_loss: 10
      pt_occ_loss: 10