# Source path: SAM2Long-Demo/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
# Hosting-page metadata captured with this copy: uploaded by Mar2Ding
# ("Upload 61 files"), commit b578f14 (verified), file size 11.4 kB.
# @package _global_
# Shared scalar settings. Other sections read these through ${scratch.*}
# interpolations, so the keys must stay nested under `scratch`.
scratch:
  resolution: 1024  # read by the resize transform (sizes) and the model (image_size)
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8  # passed to the training sampler as num_frames
  max_num_objects: 3  # passed to the training sampler as max_num_objects
  base_lr: 5.0e-6  # start value of the first lr schedule; the end value is base_lr / 10
  vision_lr: 3.0e-06  # start value of the lr schedule applied to 'image_encoder.*'
  phases_per_epoch: 1
  num_epochs: 40  # combined with phases_per_epoch to form the trainer's max_epochs
# Dataset locations and repeat factor, read elsewhere through ${dataset.*}.
dataset:
  # PATHS to Dataset
  # PATH to MOSE JPEGImages folder
  img_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/JPEGImages
  # PATH to MOSE Annotations folder
  gt_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/Annotations/
  # Optional PATH to filelist containing a subset of videos to be used for training
  file_list_txt: training/assets/MOSE_sample_train_list.txt
  multiplier: 2
# Video transforms
# Training-time transform pipeline, referenced by the dataset entry as
# ${vos.train_transforms}. A single ComposeAPI wraps the ordered transform list.
vos:
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          image_interpolation: bilinear
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        # First colour jitter: consistent_transform enabled.
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: true
          brightness: 0.1
          contrast: 0.03
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.RandomGrayscale
          p: 0.05
          consistent_transform: true
        # Second colour jitter: consistent_transform disabled, slightly
        # larger contrast/saturation ranges than the first one.
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: false
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
# Trainer definition. Mappings that carry a `_target_` key name the class or
# function to build; `${...}` values are interpolations into other sections.
# NOTE(review): the nesting below was restored from a copy whose indentation
# had been stripped. Confirm the parent of ambiguous keys (e.g. d_model and
# pos_enc_at_cross_attn_* under `layer`, `multiplier` under the VOSDataset
# entry, param_names beside `scheduler`) against the upstream config.
trainer:
  _target_: training.trainer.Trainer
  mode: train_only
  # Derived by the `times` resolver from num_epochs and phases_per_epoch.
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
        fpn_interp_model: nearest

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: true
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          kv_in_dim: 64  # same value as memory_encoder.out_dim below
      num_layers: 4

    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # Written with an explicit fraction so YAML 1.1 loaders also read it
          # as a float (a bare 1e-6 is resolved as a string there).
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: False

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 0.5  # 0.5*0.5 = 0.25 prob to use box instead of points
    prob_to_use_box_input_for_eval: 0.0
    # with a small prob, sampling correction points from GT mask instead of prediction errors
    prob_to_sample_from_gt_for_train: 0.1
    # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_train: 2
    num_frames_to_correct_for_eval: 1  # only iteratively sample on first frame
    rand_frames_to_correct_for_train: true  # random #init-cond-frame ~ 2
    # when a frame receives a correction click, it becomes a conditioning frame
    # (even if it's not initially a conditioning frame)
    add_all_frames_to_correct_as_cond: true
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: true  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: true

  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_batch_size}

      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: true
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: ${dataset.img_folder}
                  gt_folder: ${dataset.gt_folder}
                  file_list_txt: ${dataset.file_list_txt}
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: ${dataset.multiplier}
      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all  # same key as the single entry under trainer.loss

  optim:
    amp:
      enabled: true
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        # Entry without param_names: cosine schedule from base_lr to base_lr / 10.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        # Entry restricted to image-encoder parameters: vision_lr to vision_lr / 10.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - 'image_encoder.*'
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        # Zero weight decay for parameters matching '*bias*' and for LayerNorm modules.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  distributed:
    backend: nccl
    find_unused_parameters: true

  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # initialize from a SAM 2 checkpoint
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0 only last checkpoint is saved.
    model_weight_initializer:
      _partial_: true
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      strict: true
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt  # PATH to SAM 2.1 checkpoint
        ckpt_state_dict_keys: ['model']
# Launch topology and log root. trainer.logging and trainer.checkpoint build
# their paths from ${launcher.experiment_log_dir}.
launcher:
  num_nodes: 1
  gpus_per_node: 8
  experiment_log_dir: null  # Path to log directory, defaults to ./sam2_logs/${config_name}
# SLURM args if running on a cluster
submitit:
  # Scheduler selection fields; all unset (null) in this config.
  partition: null
  account: null
  qos: null
  cpus_per_task: 10  # same value as scratch.num_train_workers
  use_cluster: false
  timeout_hour: 24
  name: null
  # NOTE(review): presumably the [min, max] range a port is picked from — confirm in the launcher code.
  port_range: [10000, 65000]