# @package _global_
# Fine-tune SAM 2.1 (Hiera-B+ backbone) on the MOSE video object segmentation dataset.
# NOTE: keep the package header above as the first line of the file; Hydra reads it from
# the leading comment lines, so do not put `---` or any other content before it.

scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8
  max_num_objects: 3
  base_lr: 5.0e-6
  vision_lr: 3.0e-6
  phases_per_epoch: 1
  num_epochs: 40

dataset:
  # Paths to the dataset; set both folders to your local copy of MOSE.
  img_folder: null  # path to the MOSE JPEGImages folder
  gt_folder: null  # path to the MOSE Annotations folder
  # Optional path to a file list containing a subset of videos to be used for training
  file_list_txt: training/assets/MOSE_sample_train_list.txt
  multiplier: 2

# Video transforms
vos:
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          image_interpolation: bilinear
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: true
          brightness: 0.1
          contrast: 0.03
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.RandomGrayscale
          p: 0.05
          consistent_transform: true
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: false
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

trainer:
  _target_: training.trainer.Trainer
  mode: train_only
  # presumably num_epochs * phases_per_epoch; `times` is a custom resolver registered by the
  # training code (not visible here) — TODO confirm
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output levels 0 and 1 directly use the backbone features
        fpn_interp_model: nearest

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: true
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          kv_in_dim: 64  # matches memory_encoder.out_dim below
      num_layers: 4

    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # written with a decimal point so YAML 1.1 loaders (e.g. PyYAML) resolve it as a
          # float; a bare `1e-6` is loaded as the string "1e-6" by those resolvers
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # apply a scaled sigmoid on mask logits for the memory encoder, and directly feed the
    # input mask as the output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use the high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: false

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 0.5  # 0.5 * 0.5 = 0.25 prob to use a box instead of points
    prob_to_use_box_input_for_eval: 0.0
    # with a small prob, sample correction points from the GT mask instead of from prediction errors
    prob_to_sample_from_gt_for_train: 0.1
    # iteratively sample on a random 1~2 frames (always including the first frame)
    num_frames_to_correct_for_train: 2
    num_frames_to_correct_for_eval: 1  # only iteratively sample on the first frame
    rand_frames_to_correct_for_train: true  # random number between #init-cond-frames and 2
    # when a frame receives a correction click, it becomes a conditioning frame
    # (even if it was not initially a conditioning frame)
    add_all_frames_to_correct_as_cond: true
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: true  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: true

  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_batch_size}

      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: true
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: ${dataset.img_folder}
                  gt_folder: ${dataset.gt_folder}
                  file_list_txt: ${dataset.file_list_txt}
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: ${dataset.multiplier}
      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all

  optim:
    amp:
      enabled: true
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      # `divide` is a custom resolver registered by the training code (not visible here);
      # presumably end_value = start_value / 10 — TODO confirm
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - 'image_encoder.*'
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  distributed:
    backend: nccl
    find_unused_parameters: true

  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # Checkpointing; model weights are initialized from a SAM 2.1 checkpoint
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0 = only the last checkpoint is saved
    model_weight_initializer:
      _partial_: true
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      strict: true
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt  # path to the SAM 2.1 checkpoint
        ckpt_state_dict_keys: ['model']

launcher:
  num_nodes: 1
  gpus_per_node: 8
  experiment_log_dir: null  # path to the log directory; defaults to ./sam2_logs/${config_name}

# SLURM args if running on a cluster
submitit:
  partition: null
  account: null
  qos: null
  cpus_per_task: 10
  use_cluster: false
  timeout_hour: 24
  name: null
  port_range: [10000, 65000]