diff --git a/DOCKERFILE b/DOCKERFILE new file mode 100644 index 0000000000000000000000000000000000000000..8d1a6022b54a4edda67eacc6ba92734bc4f053a4 --- /dev/null +++ b/DOCKERFILE @@ -0,0 +1,29 @@ +FROM continuumio/anaconda3:main + +WORKDIR /code +COPY ./environment_docker.yml /code/environment_docker.yml + +# Create the environment using the environment.yml file +RUN conda env create -f /code/environment_docker.yml + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user +# Switch to the "user" user +USER user +# Set home to the user's home directory +ENV HOME=/home/user \ + PYTHONPATH=$HOME/app \ + PYTHONUNBUFFERED=1 \ + GRADIO_ALLOW_FLAGGING=never \ + GRADIO_NUM_PORTS=1 \ + GRADIO_SERVER_NAME=0.0.0.0 \ + GRADIO_THEME=huggingface \ + SYSTEM=spaces + +# Set the working directory to the user's home directory +WORKDIR $HOME/app + +# Copy the current directory contents into the container at $HOME/app setting the owner to the user +COPY --chown=user . $HOME/app + +CMD ["./run.sh"] \ No newline at end of file diff --git a/README.md b/README.md index 478b92f2234f4bfa11cf8cf41222ca8dc99264d0..22e3fc4d1177f1d937bf80c1232a63d0654fd4f3 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,8 @@ --- -title: MASA GroundingDINO -emoji: 🌍 -colorFrom: red -colorTo: pink +title: MASA + GroundingDINO Space +emoji: 🐳 +colorFrom: purple +colorTo: gray sdk: docker -pinned: false -license: mit +app_port: 7860 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5b8acb7da699cce2e26418266e2eae1e3f7505 --- /dev/null +++ b/app.py @@ -0,0 +1,71 @@ +import gradio as gr +import os +import tempfile +import subprocess + +# Define the function to call the command line script +def process_video(uploaded_video_path, texts): + with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile: + output_video_path = tmpfile.name + + command = [ + "python", "demo/video_demo_with_text.py", uploaded_video_path, + "--out", output_video_path, + "--masa_config", "configs/masa-gdino/masa_gdino_swinb_inference.py", + "--masa_checkpoint", "saved_models/masa_models/gdino_masa.pth", + "--texts", texts, + "--score-thr", "0.2", + "--unified", + "--show_fps" + ] + + subprocess.run(command, check=True) + + # Ensure the video is in a compatible format using ffmpeg + converted_output_path = output_video_path.replace('.mp4', '_converted.mp4') + ffmpeg_command = [ + "ffmpeg", "-i", output_video_path, "-c:v", "mpeg4", + "-c:a", "aac", "-b:a", "128k", "-movflags", "+faststart", converted_output_path + ] + subprocess.run(ffmpeg_command, check=True) + + return converted_output_path + +css = """ +#img-display-container { + max-height: 100vh; + } +#img-display-input { + max-height: 80vh; + } +#img-display-output { + max-height: 80vh; + } +""" + +title = "# MASA Track Everything Demo" +description = """ MASA + GroundingDINO on your video files! 
+Please refer to our [paper](https://arxiv.org/abs/2406.04221), [project page](https://matchinganything.github.io/), or [github](https://github.com/siyuanliii/masa/tree/main?tab=readme-ov-file) for more details.""" + +with gr.Blocks(css=css) as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown("### Video Object Tracking demo") + + with gr.Row(): + input_video = gr.Video(label="Input Video") + input_texts = gr.Textbox(label="Input Texts") + + submit = gr.Button("Submit") + processed_video = gr.Video(label="Processed Video") + + submit.click(process_video, inputs=[input_video, input_texts], outputs=processed_video) + + example_files = os.listdir('assets/examples_video') + example_files.sort() + example_files = [os.path.join('assets/examples_video', filename) for filename in example_files] + examples = gr.Examples(examples=example_files, inputs=[input_video, input_texts], outputs=processed_video, fn=process_video, cache_examples=True) + +if __name__ == '__main__': + demo.queue().launch() + diff --git a/configs/datasets/bdd/bdd_dataset.py b/configs/datasets/bdd/bdd_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a612faf1ae765c17983f70df8578d954ecb350c2 --- /dev/null +++ b/configs/datasets/bdd/bdd_dataset.py @@ -0,0 +1,44 @@ +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + + +test_dataset_tpye = 'BDDVideoDataset' + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + sampler=dict(type='TrackImgSampler'), + dataset=dict( + type=test_dataset_tpye, + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + data_prefix=dict(img_path='data/bdd/bdd100k/images/track/val/'), + test_mode=True, + pipeline=test_pipeline + )) + +test_dataloader = val_dataloader + +# evaluator +val_evaluator = dict( + type='BDDTETAMetric', + dataset_type=test_dataset_tpye, + format_only=False, + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + metric=['TETA']) +test_evaluator = val_evaluator + + diff --git a/configs/datasets/tao/tao_dataset_v05.py b/configs/datasets/tao/tao_dataset_v05.py new file mode 100644 index 0000000000000000000000000000000000000000..20ceae2e8a156b9b1c4ea9893772c8a9a8d8ac6c --- /dev/null +++ b/configs/datasets/tao/tao_dataset_v05.py @@ -0,0 +1,43 @@ +# data pipeline + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# dataloader + +test_dataset_tpye = 'Taov05Dataset' + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + # Now we support two ways to test, image_based and video_based + # if you want to use video_based sampling, you can use as follows + sampler=dict(type='TrackImgSampler'), # image-based sampling + dataset=dict( + type=test_dataset_tpye, + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + data_prefix=dict(img_path='data/tao/frames/'), + test_mode=True, + pipeline=test_pipeline + )) +test_dataloader = val_dataloader + +# evaluator +val_evaluator = 
dict( + type='TaoTETAMetric', + dataset_type=test_dataset_tpye, + format_only=False, + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + metric=['TETA']) +test_evaluator = val_evaluator + + diff --git a/configs/datasets/tao/tao_dataset_v1.py b/configs/datasets/tao/tao_dataset_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb844ba183a1490e0d73a0e406414dc02772966 --- /dev/null +++ b/configs/datasets/tao/tao_dataset_v1.py @@ -0,0 +1,44 @@ +# data pipeline + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# dataloader + +test_dataset_tpye = 'Taov1Dataset' + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + # Now we support two ways to test, image_based and video_based + # if you want to use video_based sampling, you can use as follows + sampler=dict(type='TrackImgSampler'), # image-based sampling + dataset=dict( + type=test_dataset_tpye, + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + data_prefix=dict(img_path='data/tao/frames/'), + test_mode=True, + pipeline=test_pipeline + )) + +test_dataloader = val_dataloader + +# evaluator +val_evaluator = dict( + type='TaoTETAMetric', + dataset_type=test_dataset_tpye, + format_only=False, + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + metric=['TETA']) +test_evaluator = val_evaluator + + diff --git a/configs/default_runtime.py b/configs/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..84897158bca71f618dcacfb4dfcc4cd7bb2f1aa8 --- /dev/null +++ b/configs/default_runtime.py @@ -0,0 +1,23 @@ +default_scope = 'mmdet' +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='DetVisualizationHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/configs/masa-detic/bdd_test/masa_detic_bdd_mot_test.py b/configs/masa-detic/bdd_test/masa_detic_bdd_mot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2cdb1522066d5bb2d2bef73d051f5c584f6e0c4b --- /dev/null +++ b/configs/masa-detic/bdd_test/masa_detic_bdd_mot_test.py @@ -0,0 +1,224 @@ +_base_ = [ + '../../../projects/Detic_new/configs/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-masa.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-ec91245d.pth' + # noqa: E501 +) +detector['type'] = 'DeticMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'bdd', + public_det_path = 
'results/public_dets/bdd_mot_yolox_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # 
loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), + checkpoint=dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + outfile_prefix='results/detic_masa_trained_bdd_demo', + metric=['TETA', 'HOTA', 'CLEAR'] +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-detic/bdd_test/masa_detic_bdd_mots_test.py b/configs/masa-detic/bdd_test/masa_detic_bdd_mots_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95e9aaea52479ca394e0d3c16b33fbcf20e03690 --- /dev/null +++ b/configs/masa-detic/bdd_test/masa_detic_bdd_mots_test.py @@ -0,0 +1,227 @@ +_base_ = [ + '../../projects/Detic_new/configs/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-masa.py', + '../datasets/bdd/bdd_dataset.py', + '../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-ec91245d.pth' + # noqa: E501 +) +detector['type'] = 'DeticMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + with_segm=True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mots_val_uninext_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + 
type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') 
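Everything after the model dict in these test configs is runtime-only: `val_cfg`/`test_cfg` select mmengine's ValLoop/TestLoop, and the dataloader/evaluator overrides at the bottom re-point the `_base_` BDD dataset at the benchmark-specific annotations. A minimal sketch of driving such a test config programmatically, assuming the repository works with a standard MMDetection-style mmengine Runner; the work_dir and checkpoint name below are hypothetical:

```python
# Sketch only: run the TestLoop of one of these MASA test configs via mmengine.
# Assumes the MASA/MMDetection packages are installed and the datasets/checkpoints
# referenced by the config exist; work_dir and checkpoint path are placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/masa-detic/bdd_test/masa_detic_bdd_mots_test.py')
cfg.work_dir = 'work_dirs/masa_detic_bdd_mots_test'        # output dir (hypothetical)
cfg.load_from = 'saved_models/masa_models/detic_masa.pth'  # checkpoint name (hypothetical)

runner = Runner.from_cfg(cfg)
runner.test()  # executes test_cfg's TestLoop with test_dataloader / test_evaluator
```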
+test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/seg_track_20/val/', + outfile_prefix='results/masa_results/masa-groundingdino-release-bdd-mots-test', + metric=['TETA', 'HOTA', 'CLEAR'], + with_mask=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-detic/open_vocabulary_mot_test/masa_detic_swinb_open_vocabulary_test.py b/configs/masa-detic/open_vocabulary_mot_test/masa_detic_swinb_open_vocabulary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd4194eb873486b7b5ba499d45b6171292d7cd3 --- /dev/null +++ b/configs/masa-detic/open_vocabulary_mot_test/masa_detic_swinb_open_vocabulary_test.py @@ -0,0 +1,236 @@ +_base_ = [ + '../../../projects/Detic_new/configs/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-masa.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-ec91245d.pth' + # noqa: E501 +) +detector['type'] = 'DeticMasa' +detector['test_cfg'] =dict( + rpn=dict( + score_thr=0.0001, + nms_pre=1000, + max_per_img=256, + nms=dict(type='nms', iou_threshold=0.9), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + ) + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = False, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', 
output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + 
ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-detic-release-ovmot-test', + open_vocabulary=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_detic_dets.py b/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_detic_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b47b9953aeb266083eee32dce612eb34f4f40d --- /dev/null +++ b/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_detic_dets.py @@ -0,0 +1,219 @@ +_base_ = [ + '../../../projects/Detic_new/configs/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-masa.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-ec91245d.pth' + # noqa: E501 +) +detector['type'] = 'DeticMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + 
rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-detic-release-detic-dets-tao-test', +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_teter_swinT_dets.py b/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_teter_swinT_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdee56f26451280f09a21b6b326e474998be37d --- /dev/null +++ b/configs/masa-detic/tao_teta_test/masa_detic_swinb_tao_test_teter_swinT_dets.py @@ -0,0 +1,219 @@ +_base_ = [ + '../../../projects/Detic_new/configs/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-masa.py', + '../../datasets/tao/tao_dataset_v05.py', + 
'../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/detic_centernet2_swin-b_fpn_4x_lvis-base_in21k-lvis-ec91245d.pth' + # noqa: E501 +) +detector['type'] = 'DeticMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark='tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/teter_swinT_tao_val_internms_50/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + 
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json' + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + outfile_prefix='results/masa_results/masa-detic-release-test', +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-gdino/bdd_test/masa_gdino_bdd_mot_test.py b/configs/masa-gdino/bdd_test/masa_gdino_bdd_mot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6d89eb28c6f766a45f3f18293fdd0c6e2afdfa60 --- /dev/null +++ b/configs/masa-gdino/bdd_test/masa_gdino_bdd_mot_test.py @@ -0,0 +1,226 @@ +_base_ = [ + '../../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +# detector.backbone.update(dict(out_indices=(1, 2, 3))) +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mot_yolox_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024, # Padding the image to multiples of 32 + ), + detector=detector, + masa_adapter=[ + dict( + 
type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + 
neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + outfile_prefix='results/detic_masa_trained_bdd_demo', +metric=['TETA', 'HOTA', 'CLEAR'] +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-gdino/bdd_test/masa_gdino_bdd_mots_test.py b/configs/masa-gdino/bdd_test/masa_gdino_bdd_mots_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f60bc242fe0f1b3619154dcbfa4b190386e959ab --- /dev/null +++ b/configs/masa-gdino/bdd_test/masa_gdino_bdd_mots_test.py @@ -0,0 +1,227 @@ +_base_ = [ + '../../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + with_segm=True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mots_val_uninext_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024, # Padding the image to multiples of 32 + ), + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', 
output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + 
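Each of these test configs repeats the same inheritance pattern: grab the detector from `_base_.model`, drop its `data_preprocessor`, re-type it (e.g. `GroundingDINOMasa`), delete `_base_.model`, and rebuild `model` as a `MASA` wrapper, with the per-benchmark dataloader/evaluator overrides (such as the `seg_track_val_cocofmt.json` annotations just below) layered on top of the shared dataset config. A small sketch of what mmengine resolves this to, assuming the repo is importable and the config path exists as in this diff:

```python
# Illustration of the _base_ override pattern used throughout these configs.
from mmengine.config import Config

cfg = Config.fromfile('configs/masa-gdino/bdd_test/masa_gdino_bdd_mots_test.py')

# The child config deleted `_base_.model` and rebuilt `model` around the detector,
# so the merged config exposes the MASA wrapper rather than raw GroundingDINO:
print(cfg.model['type'])              # 'MASA'
print(cfg.model['detector']['type'])  # 'GroundingDINOMasa'

# Dataset settings inherited from configs/datasets/bdd/bdd_dataset.py are
# re-pointed at the MOTS annotations by the overrides at the end of this file:
print(cfg.val_dataloader['dataset']['ann_file'])  # .../seg_track_val_cocofmt.json
```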
+val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/seg_track_20/val/', + outfile_prefix='results/masa_results/masa-groundingdino-release-bdd-mots-test', + metric=['TETA', 'HOTA', 'CLEAR'], + with_mask=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-gdino/masa_gdino_swinb_inference.py b/configs/masa-gdino/masa_gdino_swinb_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ffefffe68d58cd5f5dddb32f2a28c71b0fdaeb35 --- /dev/null +++ b/configs/masa-gdino/masa_gdino_swinb_inference.py @@ -0,0 +1,216 @@ +_base_ = [ + '../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = False, + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), 
+ rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.1, + obj_score_thr=0.01, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=100, + fps=30, + ) +) + +inference_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict( + type='Resize', + scale=(1333, 800), + keep_ratio=True), + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + diff --git a/configs/masa-gdino/masa_gdino_swinb_plug_and_play.py b/configs/masa-gdino/masa_gdino_swinb_plug_and_play.py new file mode 100644 index 0000000000000000000000000000000000000000..4bafa6ef11fd109e645176b82878a476cc7f52d2 --- /dev/null +++ b/configs/masa-gdino/masa_gdino_swinb_plug_and_play.py @@ -0,0 +1,218 @@ +_base_ = [ + '../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = False, + given_dets = True, + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 
116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + 
match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.1, + obj_score_thr=0.01, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=100, + fps=30, + ) +) + +inference_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict( + type='Resize', + scale=(1333, 800), + keep_ratio=True), + ]), + dict(type='PackTrackInputs') +] + + +# runtime settings +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + diff --git a/configs/masa-gdino/open_vocabulary_mot_test/masa_gdino_swinb_open_vocabulary_test.py b/configs/masa-gdino/open_vocabulary_mot_test/masa_gdino_swinb_open_vocabulary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0e7b6027c0244cc6a449a3f4f34eea5b67d08b --- /dev/null +++ b/configs/masa-gdino/open_vocabulary_mot_test/masa_gdino_swinb_open_vocabulary_test.py @@ -0,0 +1,236 @@ +_base_ = [ + '../../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +# detector.backbone.update(dict(out_indices=(1, 2, 3))) +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024, # Padding the image to multiples of 32 + ), + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + 
out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + 
type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-groundingdino-release-ovmot-test', + open_vocabulary=True, +) diff --git a/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_detic_dets.py b/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_detic_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2685ddd4501740453b758a79c8a6f092dec357 --- /dev/null +++ b/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_detic_dets.py @@ -0,0 +1,235 @@ +_base_ = [ + '../../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +# detector.backbone.update(dict(out_indices=(1, 2, 3))) +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024, # Padding the image to multiples of 32 + ), + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + 
num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', +outfile_prefix='results/masa_results/masa-groundingdino-release_detic_dets-test', +) diff --git a/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_teter_swinT_dets.py b/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_teter_swinT_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..99d6cc8595cff419a12e5f71eb1350982697f3d4 --- /dev/null +++ 
b/configs/masa-gdino/tao_teta_test/masa_gdino_swinb_tao_test_teter_swinT_dets.py @@ -0,0 +1,240 @@ +_base_ = [ + '../../../projects/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata_masa.py', + '../../datasets/tao/tao_dataset_v05.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +# detector.backbone.update(dict(out_indices=(1, 2, 3))) +detector.pop('data_preprocessor') +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/tsa_models/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' + # noqa: E501 +) +detector['type'] = 'GroundingDINOMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/teter_swinT_tao_val_internms_50/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024, # Padding the image to multiples of 32 + ), + detector=detector, + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # 
nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + outfile_prefix='results/masa_results/masa-groundingdino-release-tao-teter-test', +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-one/bdd_test/masa_r50_bdd_mot_test.py b/configs/masa-one/bdd_test/masa_r50_bdd_mot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1124727c80d33d481f4134002d4ace82ace34ed7 --- /dev/null +++ b/configs/masa-one/bdd_test/masa_r50_bdd_mot_test.py @@ -0,0 +1,235 @@ +_base_ = [ + '../../default_runtime.py', + '../../datasets/bdd/bdd_dataset.py', +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = True, + use_masa_backbone = True, + benchmark='bdd', + public_det_path='results/public_dets/bdd_mot_yolox_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding 
parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + 
loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + outfile_prefix='results/masa_results/masa-r50-release-bdd-mot-test', + metric=['TETA', 'HOTA', 'CLEAR'] +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-one/bdd_test/masa_r50_bdd_mots_test.py b/configs/masa-one/bdd_test/masa_r50_bdd_mots_test.py new file mode 100644 index 0000000000000000000000000000000000000000..47c1996174abace0e30a91511a04ef5a1ff90c64 --- /dev/null +++ b/configs/masa-one/bdd_test/masa_r50_bdd_mots_test.py @@ -0,0 +1,238 @@ +_base_ = [ + '../../default_runtime.py', + '../../datasets/bdd/bdd_dataset.py', +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = True, + use_masa_backbone = True, + benchmark='bdd', + with_segm=True, + public_det_path = 'results/public_dets/bdd_mots_val_uninext_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 
2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + 
type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) + +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/seg_track_20/val/', + outfile_prefix='results/masa_results/masa-r50-release-bdd-mots-test', + metric=['TETA', 'HOTA', 'CLEAR'], + with_mask=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-one/masa_r50_plug_and_play.py b/configs/masa-one/masa_r50_plug_and_play.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb20f94e800b10aee6ae1fbb0afa3bac7f53d87 --- /dev/null +++ b/configs/masa-one/masa_r50_plug_and_play.py @@ -0,0 +1,214 @@ +_base_ = [ + '../default_runtime.py' +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = False, + use_masa_backbone = True, + given_dets = True, + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + # detector=detector, + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + 
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='QuasiDenseTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.1, + obj_score_thr=0.01, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=100, + fps=30, + ) +) + +inference_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), + checkpoint=dict(type='CheckpointHook', interval=12), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/masa-one/open_vocabulary_mot_test/masa_r50_open_vocabulary_test.py b/configs/masa-one/open_vocabulary_mot_test/masa_r50_open_vocabulary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e32233933312b099b4a2a65a384e1eae176190 --- /dev/null +++ 
b/configs/masa-one/open_vocabulary_mot_test/masa_r50_open_vocabulary_test.py @@ -0,0 +1,231 @@ +_base_ = [ + '../../default_runtime.py', + '../../datasets/tao/tao_dataset_v1.py', +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = True, + use_masa_backbone = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn 
testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-r50-release-ovmot-test', + open_vocabulary=True, +) diff --git a/configs/masa-one/tao_teta_test/masa_r50_tao_test_detic_dets.py b/configs/masa-one/tao_teta_test/masa_r50_tao_test_detic_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..b29120d7f4d6fd58e055e39403a85df5e4a2b106 --- /dev/null +++ b/configs/masa-one/tao_teta_test/masa_r50_tao_test_detic_dets.py @@ -0,0 +1,230 @@ +_base_ = [ + '../../default_runtime.py', + '../../datasets/tao/tao_dataset_v1.py', +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = True, + use_masa_backbone = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + 
style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + 
pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-r50-release_detic_dets-test', +) diff --git a/configs/masa-one/tao_teta_test/masa_r50_tao_test_teter_swinT_dets.py b/configs/masa-one/tao_teta_test/masa_r50_tao_test_teter_swinT_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..93ea41c92a8940ca59d06c146aeea8003a15f608 --- /dev/null +++ b/configs/masa-one/tao_teta_test/masa_r50_tao_test_teter_swinT_dets.py @@ -0,0 +1,230 @@ +_base_ = [ + '../../default_runtime.py', + '../../datasets/tao/tao_dataset_v05.py', +] +default_scope = 'mmdet' + +model = dict( + type='MASA', + unified_backbone=False, + load_public_dets = True, + use_masa_backbone = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/teter_swinT_tao_val_internms_50/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=True, # In instance segmentation, the mask needs to be padded + pad_size_divisor=32), # Padding the image to multiples of 32 + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe',), + masa_adapter=[ + dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + norm_cfg=dict(type='SyncBN', requires_grad=True), + num_outs=5), + dict( + type='DeformFusion', + in_channels=256, + out_channels=256, + num_blocks=3)], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 
16, 32]), + bbox_head=dict( + type='Shared4Conv1FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = 
dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + outfile_prefix='results/masa_results/masa-r50-release-tao-teter-test', +) diff --git a/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mot_test.py b/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14a9f8613cb9c0f8deb828c2d2af2af58a265ba8 --- /dev/null +++ b/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mot_test.py @@ -0,0 +1,245 @@ +_base_ = [ + '../sam-vitb.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_b_01ec64_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mot_yolox_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[768, 768, 768, 768], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + 
pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), + checkpoint=dict(type='CheckpointHook', interval=12), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + outfile_prefix='results/masa_results/masa-sam-vitb-bdd-mot-test', + metric=['TETA', 'HOTA', 'CLEAR'] +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mots_test.py 
b/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mots_test.py new file mode 100644 index 0000000000000000000000000000000000000000..36537d446482e63f9b6140b92ccaddd65bbdc558 --- /dev/null +++ b/configs/masa-sam/bdd_test/masa_sam_vitb_bdd_mots_test.py @@ -0,0 +1,241 @@ +_base_ = [ + '../sam-vitb.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_b_01ec64_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + with_segm=True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mots_val_uninext_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[768, 768, 768, 768], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + 
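+ # class-agnostic NMS: overlapping boxes are suppressed regardless of label,
+ # which matches the single-class (num_classes=1) box head used in these configs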
nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) + +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/seg_track_20/val/', + outfile_prefix='results/masa_results/masa-sam-vitb-bdd-mots-test', + metric=['TETA', 'HOTA', 'CLEAR'], + with_mask=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mot_test.py b/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6453a5ba626ac97eb86f945afebb2db2f781238a --- /dev/null +++ b/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mot_test.py @@ -0,0 +1,246 @@ +_base_ = [ + '../sam-vith.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_h_4b8939_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'bdd', + 
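+ # load_public_dets=True: association runs on pre-computed public detections
+ # (YOLOX boxes for BDD MOT, judging by the directory name) read from the
+ # public_det_path below rather than on boxes produced by this config itself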
public_det_path = 'results/public_dets/bdd_mot_yolox_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[1280, 1280, 1280, 1280], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + 
hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), + checkpoint=dict(type='CheckpointHook', interval=12), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/bdd/annotations/box_track_20/box_track_val_cocofmt.json', + scalabel_gt='data/bdd/annotations/scalabel_gt/box_track_20/val/', + outfile_prefix='results/masa_results/masa-sam-vith-bdd-mot-test', +metric=['TETA', 'HOTA', 'CLEAR'] +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mots_test.py b/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mots_test.py new file mode 100644 index 0000000000000000000000000000000000000000..24ce6e55ab8eb69a81df33912eb230d2ff807b79 --- /dev/null +++ b/configs/masa-sam/bdd_test/masa_sam_vith_bdd_mots_test.py @@ -0,0 +1,240 @@ +_base_ = [ + '../sam-vith.py', + '../../datasets/bdd/bdd_dataset.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_h_4b8939_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + with_segm=True, + benchmark = 'bdd', + public_det_path = 'results/public_dets/bdd_mots_val_uninext_dets/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + 
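+ # the four 1280-d inputs below match embed_dim=1280 of the SAM ViT-H image
+ # encoder defined in ../sam-vith.py (one feature map per selected block)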
in_channels=[1280, 1280, 1280, 1280], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + 
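+ # tracker: appearance-based association via 'bisoftmax' with category-agnostic
+ # matching (with_cats=False); the same settings are used by the other BDD
+ # MOT/MOTS test configs in this diff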
tracker=dict( + type='MasaBDDTracker', + init_score_thr=0.5, + obj_score_thr=0.3, + match_score_thr=0.6, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=False, + match_metric='bisoftmax') +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False), +checkpoint = dict(type='CheckpointHook', interval=1), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +val_dataloader = dict( + dataset=dict( + ann_file='data/bdd/annotations/seg_track_val_cocofmt.json', + pipeline=test_pipeline, + ) +) + +test_dataloader = val_dataloader +val_evaluator = dict( + outfile_prefix='results/masa_results/masa-sam-vith-bdd-mots-test', + metric=['TETA'], + with_mask=True, +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vitb_open_vocabulary_test.py b/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vitb_open_vocabulary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f9abc1c7454ccf87c676860e41063ba5cdc86260 --- /dev/null +++ b/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vitb_open_vocabulary_test.py @@ -0,0 +1,233 @@ +_base_ = [ + '../sam-vitb.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_b_01ec64_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[768, 768, 768, 768], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', 
output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', 
vis_backends=vis_backends, name='visualizer') + +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-sam-b-release-ovmot-test', + open_vocabulary=True, +) diff --git a/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vith_open_vocabulary_test.py b/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vith_open_vocabulary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbc4d16439fa8c2b7987455d1f1def0df8818e3 --- /dev/null +++ b/configs/masa-sam/open_vocabulary_mot_test/masa_sam_vith_open_vocabulary_test.py @@ -0,0 +1,234 @@ +_base_ = [ + '../sam-vith.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_h_4b8939_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[1280, 1280, 1280, 1280], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + 
max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-sam-h-release-ovmot-test', + open_vocabulary=True, +) diff --git a/configs/masa-sam/sam-vitb.py b/configs/masa-sam/sam-vitb.py new file mode 100644 index 0000000000000000000000000000000000000000..4f318ba026466cc65bafecccf12980937bd6b861 --- /dev/null +++ b/configs/masa-sam/sam-vitb.py @@ -0,0 +1,30 @@ +prompt_embed_dim=256 +model = dict( + type='SamMasa', + backbone=dict( + type='ImageEncoderViT', + depth=12, + embed_dim=768, + img_size=1024, + mlp_ratio=4, + num_heads=12, 
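+ # SAM ViT-B image encoder: 12 blocks, 768-d embeddings, windowed attention
+ # with global attention at blocks [2, 5, 8, 11]; the same four blocks
+ # (out_indices) supply the features consumed by the MASA adapter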
+ patch_size=16, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=[2, 5, 8, 11], + window_size=14, + out_chans=prompt_embed_dim, + out_indices=[2, 5, 8, 11]), + mask_decoder=dict( + type='MaskDecoder', + num_multimask_outputs=3, + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256), + prompt_encoder=dict( + type='PromptEncoder', + embed_dim=prompt_embed_dim, + image_embedding_size=(64, 64), + input_image_size=(1024, 1024), + mask_in_chans=16), +) \ No newline at end of file diff --git a/configs/masa-sam/sam-vith.py b/configs/masa-sam/sam-vith.py new file mode 100644 index 0000000000000000000000000000000000000000..0db34bc41c641b14602f66eb4644f4282d8cb1df --- /dev/null +++ b/configs/masa-sam/sam-vith.py @@ -0,0 +1,30 @@ +prompt_embed_dim=256 +model = dict( + type='SamMasa', + backbone=dict( + type='ImageEncoderViT', + depth=32, + embed_dim=1280, + img_size=1024, + mlp_ratio=4, + num_heads=16, + patch_size=16, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=[7, 15, 23, 31], + window_size=14, + out_chans=prompt_embed_dim, + out_indices=[7, 15, 23, 31]), + mask_decoder=dict( + type='MaskDecoder', + num_multimask_outputs=3, + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256), + prompt_encoder=dict( + type='PromptEncoder', + embed_dim=prompt_embed_dim, + image_embedding_size=(64, 64), + input_image_size=(1024, 1024), + mask_in_chans=16), +) \ No newline at end of file diff --git a/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_detic_dets.py b/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_detic_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..13b47c2c276a9daa3ff2208cd4611a5111694938 --- /dev/null +++ b/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_detic_dets.py @@ -0,0 +1,232 @@ +_base_ = [ + '../sam-vitb.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_b_01ec64_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[768, 768, 768, 768], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + 
type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + +# runtime settings +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = 
[dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', +outfile_prefix='results/masa_results/masa-sam-vitb-tao-test-detic-dets', +) diff --git a/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_teter_swinT_dets.py b/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_teter_swinT_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..88ec627ab9eb70d4c562db43a18b8152b7a19914 --- /dev/null +++ b/configs/masa-sam/tao_teta_test/masa_sam_vitb_tao_test_teter_swinT_dets.py @@ -0,0 +1,238 @@ +_base_ = [ + '../sam-vitb.py', + '../../datasets/tao/tao_dataset_v05.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_b_01ec64_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/teter_swinT_tao_val_internms_50/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[768, 768, 768, 768], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + 
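+ # the rpn_proposal/rcnn training settings below follow the usual Faster R-CNN
+ # defaults; they are inert here since train_dataloader is set to None in
+ # these test-only configs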
rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', +outfile_prefix='results/masa_results/masa-sam-vitb-tao-test-teter-swinT-dets', +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_detic_dets.py b/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_detic_dets.py new file mode 100644 index 
0000000000000000000000000000000000000000..fe1e9077c094605fa8bf2fe4ef31156c58a1435c --- /dev/null +++ b/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_detic_dets.py @@ -0,0 +1,233 @@ +_base_ = [ + '../sam-vith.py', + '../../datasets/tao/tao_dataset_v1.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_h_4b8939_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + benchmark = 'tao', + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/detic_tao_val_det/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[1280, 1280, 1280, 1280], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + 
split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +test_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v1_classes.json', + outfile_prefix='results/masa_results/masa-sam-vith-tao-test-detic-dets', +) diff --git a/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_teter_swinT_dets.py b/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_teter_swinT_dets.py new file mode 100644 index 0000000000000000000000000000000000000000..9171c95bf15d7897a06e332afd268ea8fb5f061d --- /dev/null +++ b/configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_teter_swinT_dets.py @@ -0,0 +1,239 @@ +_base_ = [ + '../sam-vith.py', + '../../datasets/tao/tao_dataset_v05.py', + '../../default_runtime.py' +] +default_scope = 'mmdet' +detector = _base_.model +detector['init_cfg'] = dict( + type='Pretrained', + checkpoint= 'saved_models/pretrain_weights/sam_vit_h_4b8939_mmdet.pth' + # noqa: E501 +) +detector['type'] = 'SamMasa' + +del _base_.model + +model = dict( + type='MASA', + freeze_detector=True, + unified_backbone=True, + load_public_dets = True, + public_det_path = 'results/public_dets/tao_val_dets/teta_50_internms/teter_swinT_tao_val_internms_50/', + data_preprocessor=dict( + type='TrackDataPreprocessor', + # Image normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + 
bgr_to_rgb=True, + # Image padding parameters + pad_mask=False, # In instance segmentation, the mask needs to be padded + pad_size_divisor=1024), # Padding the image to multiples of 32 + detector=detector, + masa_adapter=[ + dict( + type='SimpleFPN', + in_channels=[1280, 1280, 1280, 1280], + out_channels=256, + use_residual=True, + num_outs=5), + dict( + type='DyHead', + in_channels=256, + out_channels=256, + num_blocks=3) + ], + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0) + ), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + # nms=dict(type='nms', iou_threshold=0.5), + nms=dict(type='nms', + iou_threshold=0.5, + class_agnostic=True, + split_thr=100000), + max_per_img=50, + mask_thr_binary=0.5) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ), + track_head=dict( + type='MasaTrackHead', + roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type='QuasiDenseEmbedHead', + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type='GN', num_groups=32), + loss_track=dict(type='UnbiasedContrastLoss', loss_weight=0.25), + loss_track_aux=dict( + type='MarginL2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + # loss_bbox=dict(type='L1Loss', loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + 
match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='CombinedSampler', + num=512, + pos_fraction=0.8, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type='InstanceBalancedPosSampler'), + neg_sampler=dict(type='RandomSampler')))), + tracker=dict( + type='MasaTaoTracker', + init_score_thr=0.0001, + obj_score_thr=0.0001, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_momentum=0.8, + with_cats=False, + max_distance=-1, + fps=1, + ) +) + +test_pipeline = [ + dict( + type='TransformBroadcaster', + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(1024, 1024), + keep_ratio=True), + dict(type='LoadTrackAnnotations') + ]), + dict(type='PackTrackInputs') +] + + +train_dataloader = None +train_cfg = None +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + visualization=dict(type='TrackVisualizationHook', draw=False)) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='MasaTrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] +auto_scale_lr = dict(enable=False, base_batch_size=16) +val_dataloader = dict( + dataset=dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + pipeline=test_pipeline, + ) +) +test_dataloader = val_dataloader +val_evaluator = dict( + ann_file='data/tao/annotations/tao_val_lvis_v05_classes.json', + outfile_prefix='results/masa_results/masa-sam-vith-tao-test-teter-swinT-dets', +) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/environment_docker.yml b/environment_docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..dfac909ecda8d24fb7d68f534984e96b006da377 --- /dev/null +++ b/environment_docker.yml @@ -0,0 +1,302 @@ +name: masaenv +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - aom=3.9.1=hac33072_0 + - blas=1.0=mkl + - brotli-python=1.0.9=py311h6a678d5_8 + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2024.6.2=hbcca054_0 + - cairo=1.18.0=h3faef2a_0 + - certifi=2024.6.2=pyhd8ed1ab_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cuda-cudart=11.8.89=0 + - cuda-cupti=11.8.87=0 + - cuda-libraries=11.8.0=0 + - cuda-nvrtc=11.8.89=0 + - cuda-nvtx=11.8.86=0 + - cuda-runtime=11.8.0=0 + - cudatoolkit=11.8.0=h6a678d5_0 + - dav1d=1.2.1=hd590300_0 + - expat=2.6.2=h59595ed_0 + - ffmpeg=7.0.1=gpl_hb399a10_100 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_2 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - freetype=2.12.1=h4a9f257_0 + - fribidi=1.0.10=h36c2ea0_0 + - gmp=6.3.0=h59595ed_1 + - gmpy2=2.1.2=py311hc9b5ff0_0 + - gnutls=3.7.9=hb077bed_0 + - graphite2=1.3.13=h59595ed_1003 + - harfbuzz=8.5.0=hfac3d4d_0 + - icu=73.2=h59595ed_0 + - idna=3.7=py311h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jinja2=3.1.4=py311h06a4308_0 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libabseil=20240116.2=cxx17_h59595ed_0 + - libass=0.17.1=h8fe9dca_1 + - libcublas=11.11.3.6=0 + - libcufft=10.9.0.58=0 + - 
libcufile=1.9.1.3=0 + - libcurand=10.3.5.147=0 + - libcusolver=11.4.1.48=0 + - libcusparse=11.7.5.86=0 + - libdeflate=1.17=h5eee18b_1 + - libdrm=2.4.120=hd590300_0 + - libexpat=2.6.2=h59595ed_0 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=13.2.0=h77fa898_10 + - libglib=2.80.2=hf974151_0 + - libgomp=13.2.0=h77fa898_10 + - libhwloc=2.10.0=default_h5622ce7_1001 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=11.8.0.86=0 + - libnsl=2.0.1=hd590300_0 + - libnvjpeg=11.9.0.86=0 + - libopenvino=2024.1.0=h2da1b83_7 + - libopenvino-auto-batch-plugin=2024.1.0=hb045406_7 + - libopenvino-auto-plugin=2024.1.0=hb045406_7 + - libopenvino-hetero-plugin=2024.1.0=h5c03a75_7 + - libopenvino-intel-cpu-plugin=2024.1.0=h2da1b83_7 + - libopenvino-intel-gpu-plugin=2024.1.0=h2da1b83_7 + - libopenvino-intel-npu-plugin=2024.1.0=he02047a_7 + - libopenvino-ir-frontend=2024.1.0=h5c03a75_7 + - libopenvino-onnx-frontend=2024.1.0=h07e8aee_7 + - libopenvino-paddle-frontend=2024.1.0=h07e8aee_7 + - libopenvino-pytorch-frontend=2024.1.0=he02047a_7 + - libopenvino-tensorflow-frontend=2024.1.0=h39126c6_7 + - libopenvino-tensorflow-lite-frontend=2024.1.0=he02047a_7 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.39=h5eee18b_0 + - libprotobuf=4.25.3=h08a7969_0 + - libsqlite=3.46.0=hde9e2c9_0 + - libstdcxx-ng=13.2.0=hc0a3c3a_10 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.21.0=h4ab18f5_2 + - libvpx=1.14.1=hac33072_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxcb=1.15=h0b41bf4_0 + - libxcrypt=4.4.36=hd590300_1 + - libxml2=2.12.7=hc051c1a_1 + - libzlib=1.2.13=h4ab18f5_6 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=2.1.3=py311h5eee18b_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py311h5eee18b_1 + - mkl_fft=1.3.8=py311h5eee18b_0 + - mkl_random=1.2.4=py311hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py311h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.9.1=h7ab15ed_0 + - networkx=3.2.1=py311h06a4308_0 + - numpy=1.26.4=py311h08b1b3b_0 + - numpy-base=1.26.4=py311hf175353_0 + - ocl-icd=2.3.2=hd590300_1 + - openh264=2.4.1=h59595ed_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.3.1=h4ab18f5_0 + - p11-kit=0.24.1=hc5aa10d_0 + - pcre2=10.43=hcad00b1_0 + - pillow=10.3.0=py311h5eee18b_0 + - pip=24.0=py311h06a4308_0 + - pixman=0.43.2=h59595ed_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - pugixml=1.14=h59595ed_0 + - pysocks=1.7.1=py311h06a4308_0 + - python=3.11.8=hab00c5b_0_cpython + - pytorch=2.1.2=py3.11_cuda11.8_cudnn8.7.0_0 + - pytorch-cuda=11.8=h7e8668a_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py311h5eee18b_0 + - readline=8.2=h5eee18b_0 + - snappy=1.2.0=hdb0a2a9_1 + - sqlite=3.45.3=h5eee18b_0 + - svt-av1=2.1.0=hac33072_0 + - sympy=1.12=py311h06a4308_0 + - tbb=2021.12.0=h297d8ca_1 + - tk=8.6.14=h39e8969_0 + - torchaudio=2.1.2=py311_cu118 + - torchtriton=2.1.0=py311 + - torchvision=0.16.2=py311_cu118 + - typing_extensions=4.11.0=py311h06a4308_0 + - wheel=0.43.0=py311h06a4308_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.9=h8ee46fc_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxrender=0.9.11=hd590300_0 + - 
xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=h4ab18f5_6 + - zstd=1.5.5=hc292b87_2 + - pip: + - addict==2.4.0 + - aiofiles==23.2.1 + - aliyun-python-sdk-core==2.15.1 + - aliyun-python-sdk-kms==2.16.3 + - altair==5.3.0 + - annotated-types==0.7.0 + - anyio==4.4.0 + - attrs==23.2.0 + - boto3==1.34.128 + - botocore==1.34.128 + - cffi==1.16.0 + - click==8.1.7 + - clip==1.0 + - colorama==0.4.6 + - contourpy==1.2.1 + - crcmod==1.7 + - cryptography==42.0.8 + - cycler==0.12.1 + - cython==3.0.10 + - decorator==4.4.2 + - defusedxml==0.7.1 + - dnspython==2.6.1 + - einops==0.8.0 + - email-validator==2.1.2 + - fairscale==0.4.13 + - fastapi==0.111.0 + - fastapi-cli==0.0.4 + - ffmpy==0.3.2 + - filelock==3.14.0 + - fonttools==4.53.0 + - fsspec==2024.6.0 + - ftfy==6.2.0 + - gradio==4.36.1 + - gradio-client==1.0.1 + - h11==0.14.0 + - h5py==3.11.0 + - httpcore==1.0.5 + - httptools==0.6.1 + - httpx==0.27.0 + - huggingface-hub==0.23.4 + - imageio==2.34.1 + - importlib-metadata==7.1.0 + - importlib-resources==6.4.0 + - jmespath==0.10.0 + - joblib==1.4.2 + - jsonschema==4.22.0 + - jsonschema-specifications==2023.12.1 + - kiwisolver==1.4.5 + - llvmlite==0.43.0 + - lvis==0.5.3 + - markdown==3.6 + - markdown-it-py==3.0.0 + - matplotlib==3.9.0 + - mdurl==0.1.2 + - mmcv==2.1.0 + - mmdet==3.3.0 + - mmengine==0.10.4 + - model-index==0.1.11 + - motmetrics==1.4.0 + - moviepy==0.2.3.5 + - nanoid==2.0.0 + - natsort==8.4.0 + - nltk==3.8.1 + - numba==0.60.0 + - opencv-python==4.10.0.84 + - opencv-python-headless==4.10.0.84 + - opendatalab==0.0.10 + - openmim==0.3.9 + - openxlab==0.1.0 + - ordered-set==4.1.0 + - orjson==3.10.5 + - oss2==2.17.0 + - packaging==24.1 + - pandas==2.2.2 + - platformdirs==4.2.2 + - plyfile==1.0.3 + - psutil==5.9.8 + - pycocotools==2.0.8 + - pycparser==2.22 + - pycryptodome==3.20.0 + - pydantic==2.7.4 + - pydantic-core==2.18.4 + - pydub==0.25.1 + - pygments==2.18.0 + - pyparsing==3.1.2 + - python-dateutil==2.9.0.post0 + - python-dotenv==1.0.1 + - python-multipart==0.0.9 + - pytz==2023.4 + - referencing==0.35.1 + - regex==2024.5.15 + - requests==2.32.3 + - rich==13.4.2 + - rpds-py==0.18.1 + - ruff==0.4.9 + - s3transfer==0.10.1 + - safetensors==0.4.3 + - scalabel==0.3.0 + - scipy==1.13.1 + - script-utils==0.0.1 + - seaborn==0.13.2 + - semantic-version==2.10.0 + - setuptools==60.2.0 + - shapely==2.0.4 + - shellingham==1.5.4 + - six==1.16.0 + - sniffio==1.3.1 + - starlette==0.37.2 + - supervision==0.21.0 + - tabulate==0.9.0 + - tao==0.1.0 + - termcolor==2.4.0 + - terminaltables==3.1.10 + - teta==0.1.0 + - tokenizers==0.15.2 + - toml==0.10.2 + - tomli==2.0.1 + - tomlkit==0.12.0 + - toolz==0.12.1 + - tqdm==4.65.2 + - trackeval==1.0.dev1 + - transformers==4.38.2 + - typer==0.12.3 + - tzdata==2024.1 + - ujson==5.10.0 + - urllib3==2.2.2 + - uvicorn==0.30.1 + - uvloop==0.19.0 + - watchfiles==0.22.0 + - wcwidth==0.2.13 + - websockets==11.0.3 + - xmltodict==0.13.0 + - yacs==0.1.8 + - yapf==0.40.2 + - youtube-dl==2021.12.17 + - zipp==3.19.2 diff --git a/masa/__init__.py b/masa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b93e00d7af7a2378ca7e9826a9ba5311bc87a4e --- /dev/null +++ b/masa/__init__.py @@ -0,0 +1,3 @@ +from .datasets import * # noqa +from .models import * # noqa +from .visualization import * # noqa diff --git a/masa/__pycache__/__init__.cpython-311.pyc b/masa/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..16d2e4b1311229dbd8ac5c359d98d17a2a8d6bd8 Binary files /dev/null and b/masa/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/apis/__init__.py b/masa/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e60ff2cde74d46b5f1093793b3fb96485578ad43 --- /dev/null +++ b/masa/apis/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .masa_inference import (build_test_pipeline, inference_detector, + inference_masa, init_masa) + +__all__ = [ + "inference_masa", + "init_masa", + "inference_detector", + "build_test_pipeline", +] diff --git a/masa/apis/__pycache__/__init__.cpython-311.pyc b/masa/apis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1c9d863d46f4399028662b9868e9caefb4740c1 Binary files /dev/null and b/masa/apis/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/apis/__pycache__/masa_inference.cpython-311.pyc b/masa/apis/__pycache__/masa_inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb63f690211cee140265da3013b5d7816f20d38e Binary files /dev/null and b/masa/apis/__pycache__/masa_inference.cpython-311.pyc differ diff --git a/masa/apis/masa_inference.py b/masa/apis/masa_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8079a42b380c801c633a7dc5a075ae126f0b5bec --- /dev/null +++ b/masa/apis/masa_inference.py @@ -0,0 +1,297 @@ +import copy +import time +import warnings +from pathlib import Path +from typing import Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.ops import RoIPool +from mmcv.transforms import Compose +from mmdet.evaluation import get_classes +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample, SampleList +from mmdet.utils import ConfigType, get_test_pipeline_cfg +from mmengine.config import Config +from mmengine.dataset import default_collate +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import autocast, load_checkpoint + +ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] + + +def init_masa( + config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + palette: str = "none", + device: str = "cuda:0", + cfg_options: Optional[dict] = None, +) -> nn.Module: + """Initialize a unified masa detector from config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + palette (str): Color palette used for visualization. If palette + is stored in checkpoint, use checkpoint's palette first, otherwise + use externally passed palette. Currently, supports 'coco', 'voc', + 'citys' and 'random'. Defaults to none. + device (str): The device where the anchors will be put on. + Defaults to cuda:0. + cfg_options (dict, optional): Options to override some settings in + the used config. + + Returns: + nn.Module: The constructed detector. 
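    Example (an illustrative sketch; the config path is one of the files added in this
    diff, ``checkpoint`` is left as None so COCO class names are used, and any
    pretrained weights referenced by that config must exist locally for the build
    to succeed):
        >>> from masa.apis import init_masa
        >>> model = init_masa(
        ...     'configs/masa-sam/tao_teta_test/masa_sam_vith_tao_test_detic_dets.py',
        ...     checkpoint=None,
        ...     device='cuda:0')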
+ """ + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError( + "config must be a filename or Config object, " f"but got {type(config)}" + ) + + with_backbone = config.model.get("backbone", False) + if with_backbone: + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif "init_cfg" in config.model.backbone: + config.model.backbone.init_cfg = None + else: + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif "init_cfg" in config.model.detector.backbone: + config.model.detector.backbone.init_cfg = None + + scope = config.get("default_scope", "mmdet") + if scope is not None: + init_default_scope(config.get("default_scope", "mmdet")) + + model = MODELS.build(config.model) + model = revert_sync_batchnorm(model) + if checkpoint is None: + warnings.simplefilter("once") + warnings.warn("checkpoint is None, use COCO classes by default.") + model.dataset_meta = {"classes": get_classes("coco")} + else: + checkpoint = load_checkpoint(model, checkpoint, map_location="cpu") + # Weights converted from elsewhere may not have meta fields. + checkpoint_meta = checkpoint.get("meta", {}) + + # save the dataset_meta in the model for convenience + if "dataset_meta" in checkpoint_meta: + # mmdet 3.x, all keys should be lowercase + model.dataset_meta = { + k.lower(): v for k, v in checkpoint_meta["dataset_meta"].items() + } + elif "CLASSES" in checkpoint_meta: + # < mmdet 3.x + classes = checkpoint_meta["CLASSES"] + model.dataset_meta = {"classes": classes} + else: + warnings.simplefilter("once") + warnings.warn( + "dataset_meta or class names are not saved in the " + "checkpoint's meta data, use COCO classes by default." + ) + model.dataset_meta = {"classes": get_classes("coco")} + + # Priority: args.palette -> config -> checkpoint + if palette != "none": + model.dataset_meta["palette"] = palette + else: + if "palette" not in model.dataset_meta: + warnings.warn( + "palette does not exist, random is used by default. " + "You can also set the palette to customize." + ) + model.dataset_meta["palette"] = "random" + + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def inference_detector( + model: nn.Module, + imgs: ImagesType, + test_pipeline: Optional[Compose] = None, + text_prompt: Optional[str] = None, + custom_entities: bool = False, + fp16: bool = False, +) -> Union[DetDataSample, SampleList]: + """Inference image(s) with the detector. + + Args: + model (nn.Module): The loaded detector. + imgs (str, ndarray, Sequence[str/ndarray]): + Either image files or loaded images. + test_pipeline (:obj:`Compose`): Test pipeline. + + Returns: + :obj:`DetDataSample` or list[:obj:`DetDataSample`]: + If imgs is a list or tuple, the same length list type results + will be returned, otherwise return the detection results directly. + """ + + if isinstance(imgs, (list, tuple)): + is_batch = True + else: + imgs = [imgs] + is_batch = False + + cfg = model.cfg + + if test_pipeline is None: + cfg = cfg.copy() + test_pipeline = get_test_pipeline_cfg(cfg) + if isinstance(imgs[0], np.ndarray): + # Calling this method across libraries will result + # in module unregistered error if not prefixed with mmdet. 
+ test_pipeline[0].type = "mmdet.LoadImageFromNDArray" + + test_pipeline = Compose(test_pipeline) + + if model.data_preprocessor.device.type == "cpu": + for m in model.modules(): + assert not isinstance( + m, RoIPool + ), "CPU inference with RoIPool is not supported currently." + + result_list = [] + for i, img in enumerate(imgs): + # prepare data + if isinstance(img, np.ndarray): + # TODO: remove img_id. + data_ = dict(img=img, img_id=0) + else: + # TODO: remove img_id. + data_ = dict(img_path=img, img_id=0) + + if text_prompt: + data_["text"] = text_prompt + data_["custom_entities"] = custom_entities + + # build the data pipeline + data_ = test_pipeline(data_) + + data_["inputs"] = [data_["inputs"]] + data_["data_samples"] = [data_["data_samples"]] + + # forward the model + with torch.no_grad(): + with autocast(enabled=fp16): + results = model.test_step(data_)[0] + + result_list.append(results) + + if not is_batch: + return result_list[0] + else: + return result_list + + +def inference_masa( + model: nn.Module, + img: np.ndarray, + frame_id: int, + video_len: int, + test_pipeline: Optional[Compose] = None, + text_prompt=None, + custom_entities: bool = False, + det_bboxes=None, + det_labels=None, + fp16=False, + detector_type="mmdet", + show_fps=False, +) -> SampleList: + """Inference image(s) with the masa model. + + Args: + model (nn.Module): The loaded mot model. + img (np.ndarray): Loaded image. + frame_id (int): frame id. + video_len (int): demo video length + Returns: + SampleList: The tracking data samples. + """ + data = dict( + img=[img.astype(np.float32)], + # img=[img.astype(np.uint8)], + frame_id=[frame_id], + ori_shape=[img.shape[:2]], + img_id=[frame_id + 1], + ori_video_length=[video_len], + ) + + if text_prompt is not None: + if detector_type == "mmdet": + data["text"] = [text_prompt] + data["custom_entities"] = [custom_entities] + elif detector_type == "yolo-world": + data["texts"] = [text_prompt] + data["custom_entities"] = [custom_entities] + + data = test_pipeline(data) + + # forward the model + with torch.no_grad(): + data = default_collate([data]) + if det_bboxes is not None: + data["data_samples"][0].video_data_samples[0].det_bboxes = det_bboxes + data["data_samples"][0].video_data_samples[0].det_labels = det_labels + # measure FPS ## + if show_fps: + start = time.time() + with autocast(enabled=fp16): + result = model.test_step(data)[0] + end = time.time() + fps = 1 / (end - start) + return result, fps + + else: + with autocast(enabled=fp16): + result = model.test_step(data)[0] + return result + + +def build_test_pipeline( + cfg: ConfigType, with_text=False, detector_type="mmdet" +) -> ConfigType: + """Build test_pipeline for mot/vis demo. In mot/vis infer, original + test_pipeline should remove the "LoadImageFromFile" and + "LoadTrackAnnotations". + + Args: + cfg (ConfigDict): The loaded config. 
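    Example (a hedged sketch of the usual per-frame video loop combining this builder
    with ``inference_masa`` and the ``model`` from the ``init_masa`` example above; it
    assumes the chosen config defines an ``inference_pipeline`` entry, which this
    function reads, and uses OpenCV for frame reading, which may differ from the
    repository's own demo script):
        >>> import cv2
        >>> pipeline = build_test_pipeline(model.cfg, with_text=True)
        >>> cap = cv2.VideoCapture('my_video.mp4')  # hypothetical input path
        >>> video_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        >>> frame_id = 0
        >>> while True:
        ...     ok, frame = cap.read()
        ...     if not ok:
        ...         break
        ...     track_result = inference_masa(
        ...         model, frame, frame_id, video_len,
        ...         test_pipeline=pipeline,
        ...         text_prompt='person . car',  # illustrative open-vocabulary prompt
        ...         custom_entities=True)
        ...     frame_id += 1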
+ Returns: + ConfigType: new test_pipeline + """ + # remove the "LoadImageFromFile" and "LoadTrackAnnotations" in pipeline + transform_broadcaster = cfg.inference_pipeline[0].copy() + if detector_type == "yolo-world": + kept_transform = [] + for transform in transform_broadcaster["transforms"]: + if ( + transform["type"] == "mmyolo.YOLOv5KeepRatioResize" + or transform["type"] == "mmyolo.LetterResize" + ): + kept_transform.append(transform) + transform_broadcaster["transforms"] = kept_transform + pack_track_inputs = cfg.test_dataloader.dataset.pipeline[-1].copy() + test_pipeline = Compose([transform_broadcaster, pack_track_inputs]) + else: + for transform in transform_broadcaster["transforms"]: + if "Resize" in transform["type"]: + transform_broadcaster["transforms"] = transform + pack_track_inputs = cfg.inference_pipeline[-1].copy() + if with_text: + pack_track_inputs["meta_keys"] = ("text", "custom_entities") + test_pipeline = Compose([transform_broadcaster, pack_track_inputs]) + + return test_pipeline diff --git a/masa/datasets/__init__.py b/masa/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2dec4f76e425fbe342ddc9c585a1a076b9623d85 --- /dev/null +++ b/masa/datasets/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Tencent Inc. All rights reserved. +from .bdd_masa_dataset import BDDVideoDataset +from .dataset_wrappers import SeqMultiImageMixDataset +from .evaluation import * # NOQA +from .masa_dataset import MASADataset +from .pipelines import * # NOQA +from .rsconcat_dataset import RandomSampleConcatDataset +from .tao_masa_dataset import Taov1Dataset, Taov05Dataset +from .utils import yolow_collate + +__all__ = [ + "yolow_collate", + "RandomSampleConcatDataset", + "MASADataset", + "SeqMultiImageMixDataset", + "Taov05Dataset", + "Taov1Dataset", + "BDDVideoDataset", +] diff --git a/masa/datasets/__pycache__/__init__.cpython-311.pyc b/masa/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e869259e48fd1cbd48bbdf2fda383da704a4a956 Binary files /dev/null and b/masa/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/datasets/__pycache__/bdd_masa_dataset.cpython-311.pyc b/masa/datasets/__pycache__/bdd_masa_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a821bf00a906cdbef6129041bf2fb71833b55ede Binary files /dev/null and b/masa/datasets/__pycache__/bdd_masa_dataset.cpython-311.pyc differ diff --git a/masa/datasets/__pycache__/dataset_wrappers.cpython-311.pyc b/masa/datasets/__pycache__/dataset_wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ad5c33058389d52baefb3fcfdaf5d24e0ce980 Binary files /dev/null and b/masa/datasets/__pycache__/dataset_wrappers.cpython-311.pyc differ diff --git a/masa/datasets/__pycache__/masa_dataset.cpython-311.pyc b/masa/datasets/__pycache__/masa_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c75da21756bb8f5790de4819ccab80a4bdd2491 Binary files /dev/null and b/masa/datasets/__pycache__/masa_dataset.cpython-311.pyc differ diff --git a/masa/datasets/__pycache__/rsconcat_dataset.cpython-311.pyc b/masa/datasets/__pycache__/rsconcat_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67cd17bf802fa10775f264725101bc0427fbae60 Binary files /dev/null and b/masa/datasets/__pycache__/rsconcat_dataset.cpython-311.pyc differ diff --git 
a/masa/datasets/__pycache__/tao_masa_dataset.cpython-311.pyc b/masa/datasets/__pycache__/tao_masa_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c417019acd32f61edf46b1f607a490a78807ecc Binary files /dev/null and b/masa/datasets/__pycache__/tao_masa_dataset.cpython-311.pyc differ diff --git a/masa/datasets/__pycache__/utils.cpython-311.pyc b/masa/datasets/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c15d9a3418a328bc2967bffb0a3f45794015a0b Binary files /dev/null and b/masa/datasets/__pycache__/utils.cpython-311.pyc differ diff --git a/masa/datasets/bdd_masa_dataset.py b/masa/datasets/bdd_masa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..59aad57a5cffcf701f6bd4a51d6f325d9c17bb21 --- /dev/null +++ b/masa/datasets/bdd_masa_dataset.py @@ -0,0 +1,102 @@ +from collections import defaultdict +from typing import Any, List, Tuple + +import numpy as np +from mmdet.datasets import BaseVideoDataset +from mmdet.registry import DATASETS + + +@DATASETS.register_module() +class BDDVideoDataset(BaseVideoDataset): + """Dataset for the BDD100K tracking benchmark. + """ + + METAINFO = { + "classes": ( + "pedestrian", + "rider", + "car", + "truck", + "bus", + "train", + "motorcycle", + "bicycle", + ), + "palette": None, + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.flag = np.zeros(len(self), dtype=np.uint8) + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def prepare_data(self, idx) -> Any: + """Get data processed by ``self.pipeline``. Note that ``idx`` is a + video index by default since the base element of a video dataset is a + video. However, in some cases, we need to specify both the video index + and the frame index. For example, in training mode, we may want to sample + specific frames, and all frames must be sampled once in an epoch; in + test mode, we may want to output data of a single image rather than the + whole video to save memory. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. 
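        Example (illustrative only; importing ``masa`` triggers dataset registration,
        and the annotation file path below is a hypothetical placeholder):
            >>> import masa  # noqa: F401  (registers BDDVideoDataset)
            >>> from mmdet.registry import DATASETS
            >>> dataset = DATASETS.build(dict(
            ...     type='BDDVideoDataset',
            ...     ann_file='path/to/bdd_box_track_val.json',  # hypothetical path
            ...     test_mode=True,
            ...     pipeline=[]))
            >>> whole_video = dataset.prepare_data(0)        # every frame of video 0
            >>> single_frame = dataset.prepare_data((0, 5))  # only frame 5 of video 0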
+ """ + if isinstance(idx, tuple): + assert len(idx) == 2, "The length of idx must be 2: " + "(video_index, frame_index)" + video_idx, frame_idx = idx[0], idx[1] + else: + video_idx, frame_idx = idx, None + + data_info = self.get_data_info(video_idx) + if self.test_mode: + # Support two test_mode: frame-level and video-level + final_data_info = defaultdict(list) + if frame_idx is None: + frames_idx_list = list(range(data_info["video_length"])) + else: + frames_idx_list = [frame_idx] + for index in frames_idx_list: + frame_ann = data_info["images"][index] + frame_ann["video_id"] = data_info["video_id"] + # Collate data_list (list of dict to dict of list) + for key, value in frame_ann.items(): + final_data_info[key].append(value) + # copy the info in video-level into img-level + # TODO: the value of this key is the same as that of + # `video_length` in test mode + final_data_info["ori_video_length"].append(data_info["video_length"]) + + final_data_info["video_length"] = [len(frames_idx_list)] * len( + frames_idx_list + ) + return self.pipeline(final_data_info) + else: + # Specify `key_frame_id` for the frame sampling in the pipeline + if frame_idx is not None: + data_info["key_frame_id"] = frame_idx + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. + try: + + data = self.pipeline(data_info) + except Exception as e: + print("Error occurred while running pipeline", f" with error: {e}") + # print('Empty instances due to augmentation, re-sampling...') + video_idx = self._rand_another(video_idx) + data_info = self.get_data_info(video_idx) + continue + + if data is not None: + break + return data diff --git a/masa/datasets/dataset_wrappers.py b/masa/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..187a426e6a0d7b1d2409e7b1ad1f7794264dc008 --- /dev/null +++ b/masa/datasets/dataset_wrappers.py @@ -0,0 +1,418 @@ +import collections +import copy +import random +from typing import List, Sequence, Union + +import numpy as np +from mmdet.datasets.base_det_dataset import BaseDetDataset +from mmdet.datasets.base_video_dataset import BaseVideoDataset +from mmdet.registry import DATASETS, TRANSFORMS +from mmengine.dataset import BaseDataset, force_full_init + +from .rsconcat_dataset import RandomSampleJointVideoConcatDataset + + +@DATASETS.register_module(force=True) +class SeqMultiImageMixDataset: + """A wrapper of multiple images mixed dataset. + + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. For the augmentation pipeline of mixed image data, + the `get_indexes` method needs to be provided to obtain the image + indexes, and you can set `skip_flags` to change the pipeline running + process. At the same time, we provide the `dynamic_scale` parameter + to dynamically change the output image size. + + Args: + dataset (:obj:`CustomDataset`): The dataset to be mixed. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + dynamic_scale (tuple[int], optional): The image scale can be changed + dynamically. Default to None. It is deprecated. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + max_refetch (int): The maximum number of retry iterations for getting + valid results from the pipeline. If the number of iterations is + greater than `max_refetch`, but results is still None, then the + iteration is terminated and raise the error. Default: 15. 
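    Example (a hedged, config-style sketch; the inner dataset arguments and the exact
    transform names used by this repository's training configs are not reproduced here,
    only the type names this wrapper special-cases):
        >>> train_dataset = dict(
        ...     type='SeqMultiImageMixDataset',
        ...     dataset=dict(type='MASADataset'),  # hypothetical inner dataset config, arguments omitted
        ...     pipeline=[dict(type='SeqMosaic'), dict(type='SeqMixUp')],  # illustrative mixed-image transforms
        ...     max_refetch=15)  # a sample is re-drawn up to this many times if augmentation fails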
+ """ + + def __init__( + self, + dataset: Union[BaseDataset, dict], + pipeline: Sequence[str], + skip_type_keys: Union[Sequence[str], None] = None, + max_refetch: int = 15, + lazy_init: bool = False, + ) -> None: + assert isinstance(pipeline, collections.abc.Sequence) + if skip_type_keys is not None: + assert all( + [isinstance(skip_type_key, str) for skip_type_key in skip_type_keys] + ) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform["type"]) + transform = TRANSFORMS.build(transform) + self.pipeline.append(transform) + else: + raise TypeError("pipeline must be a dict") + + self.dataset: BaseDataset + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError( + "elements in datasets sequence should be config or " + f"`BaseDataset` instance, but got {type(dataset)}" + ) + + self._metainfo = self.dataset.metainfo + if hasattr(self.dataset, "flag"): + self.flag = self.dataset.flag + self.num_samples = len(self.dataset) + self.max_refetch = max_refetch + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + self.generate_indices() + + def generate_indices(self): + cat_datasets = self.dataset.datasets + for dataset in cat_datasets: + self.test_mode = dataset.test_mode + assert not self.test_mode, "'ConcatDataset' should not exist in " + "test mode" + video_indices = [] + img_indices = [] + if isinstance(dataset, BaseVideoDataset): + num_videos = len(dataset) + for video_ind in range(num_videos): + video_indices.extend( + [ + (video_ind, frame_ind) + for frame_ind in range(dataset.get_len_per_video(video_ind)) + ] + ) + elif isinstance(dataset, BaseDetDataset): + num_imgs = len(dataset) + for img_ind in range(num_imgs): + img_indices.extend([img_ind]) + + ###### special process to make debug task easier ##### + def alternate_merge(list1, list2): + # Create a new list to hold the merged elements + merged_list = [] + + # Get the length of the shorter list + min_length = min(len(list1), len(list2)) + + # Append elements alternately from both lists + for i in range(min_length): + merged_list.append(list1[i]) + merged_list.append(list2[i]) + + # Append the remaining elements from the longer list + if len(list1) > len(list2): + merged_list.extend(list1[min_length:]) + else: + merged_list.extend(list2[min_length:]) + + return merged_list + + self.indices = alternate_merge(img_indices, video_indices) + + @property + def metainfo(self) -> dict: + """Get the meta information of the multi-image-mixed dataset. + + Returns: + dict: The meta information of multi-image-mixed dataset. + """ + return copy.deepcopy(self._metainfo) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._ori_len = len(self.dataset) + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. 
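        For reference, a small worked illustration of the ``alternate_merge`` helper
        used inside ``generate_indices`` above (inputs are illustrative): image indices
        and ``(video, frame)`` indices are interleaved, and leftover entries from the
        longer list are appended at the end.
            >>> alternate_merge([0, 1, 2], [(0, 0), (0, 1)])
            [0, (0, 0), 1, (0, 1), 2]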
+ """ + return self.dataset.get_data_info(idx) + + @force_full_init + def get_transform_indexes(self, transform, results, t_type="SeqMosaic"): + num_samples = len(results["img_id"]) + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. + indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [copy.deepcopy(self.dataset[index]) for index in indexes] + if None not in mix_results: + if t_type == "SeqMosaic": + results["mosaic_mix_results"] = [mix_results] * num_samples + elif t_type == "SeqMixUp": + results["mixup_mix_results"] = [mix_results] * num_samples + elif t_type == "SeqCopyPaste": + results["copypaste_mix_results"] = [mix_results] * num_samples + return results + else: + raise RuntimeError( + "The loading pipeline of the original dataset" + " always return None. Please check the correctness " + "of the dataset and its pipeline." + ) + + @force_full_init + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + while True: + results = copy.deepcopy(self.dataset[idx]) + + for (transform, transform_type) in zip(self.pipeline, self.pipeline_types): + if ( + self._skip_type_keys is not None + and transform_type in self._skip_type_keys + ): + continue + if transform_type == "MasaTransformBroadcaster": + for sub_transform in transform.transforms: + if hasattr(sub_transform, "get_indexes"): + sub_transform_type = type(sub_transform).__name__ + results = self.get_transform_indexes( + sub_transform, results, sub_transform_type + ) + + elif hasattr(transform, "get_indexes"): + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. + indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indexes + ] + if None not in mix_results: + results["mix_results"] = mix_results + break + else: + raise RuntimeError( + "The loading pipeline of the original dataset" + " always return None. Please check the correctness " + "of the dataset and its pipeline." + ) + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. + try: + updated_results = transform(copy.deepcopy(results)) + except Exception as e: + print( + "Error occurred while running pipeline", + f"{transform} with error: {e}", + ) + # print('Empty instances due to augmentation, re-sampling...') + idx = self._rand_another(idx) + continue + if updated_results is not None: + results = updated_results + break + else: + raise RuntimeError( + "The training pipeline of the dataset wrapper" + " always return None.Please check the correctness " + "of the dataset and its pipeline." + ) + + if "mosaic_mix_results" in results: + results.pop("mosaic_mix_results") + + if "mixup_mix_results" in results: + results.pop("mixup_mix_results") + + if "copypaste_mix_results" in results: + results.pop("copypaste_mix_results") + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. It is called by an external hook. + + Args: + skip_type_keys (list[str], optional): Sequence of type + string to be skip pipeline. 
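        Example (illustrative; ``wrapper`` stands for a built ``SeqMultiImageMixDataset``
        instance, and the calling hook is an assumption rather than code defined in
        this diff):
            >>> wrapper.update_skip_type_keys(['SeqMosaic', 'SeqMixUp'])  # e.g. disable mixed-image augmentation late in training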
+ """ + assert all([isinstance(skip_type_key, str) for skip_type_key in skip_type_keys]) + self._skip_type_keys = skip_type_keys + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + return np.random.choice(self.indices) + + +@DATASETS.register_module() +class SeqRandomMultiImageVideoMixDataset(SeqMultiImageMixDataset): + def __init__( + self, video_pipeline: Sequence[str], video_sample_ratio=0.5, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.video_pipeline = [] + self.video_pipeline_types = [] + for transform in video_pipeline: + if isinstance(transform, dict): + self.video_pipeline_types.append(transform["type"]) + transform = TRANSFORMS.build(transform) + self.video_pipeline.append(transform) + else: + raise TypeError("pipeline must be a dict") + + self.video_sample_ratio = video_sample_ratio + assert isinstance(self.dataset, RandomSampleJointVideoConcatDataset) + + @force_full_init + def get_transform_indexes( + self, transform, results, sample_video, t_type="SeqMosaic" + ): + num_samples = len(results["img_id"]) + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. + + indexes = transform.get_indexes(self.dataset.datasets[0]) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + if sample_video: + mix_results = [copy.deepcopy(self.dataset[0]) for index in indexes] + else: + mix_results = [copy.deepcopy(self.dataset[1]) for index in indexes] + + if None not in mix_results: + if t_type == "SeqMosaic": + results["mosaic_mix_results"] = [mix_results] * num_samples + elif t_type == "SeqMixUp": + results["mixup_mix_results"] = [mix_results] * num_samples + elif t_type == "SeqCopyPaste": + results["copypaste_mix_results"] = [mix_results] * num_samples + return results + else: + raise RuntimeError( + "The loading pipeline of the original dataset" + " always return None. Please check the correctness " + "of the dataset and its pipeline." + ) + + def __getitem__(self, idx): + + while True: + if random.random() < self.video_sample_ratio: + sample_video = True + else: + sample_video = False + if sample_video: + results = copy.deepcopy(self.dataset[0]) + pipeline = self.video_pipeline + pipeline_type = self.video_pipeline_types + + else: + results = copy.deepcopy(self.dataset[1]) + pipeline = self.pipeline + pipeline_type = self.pipeline_types + # if results['img_id'][0] != results['img_id'][1]: + # self.update_skip_type_keys(['SeqMosaic', 'SeqMixUp']) + # else: + # self._skip_type_keys = None + + for (transform, transform_type) in zip(pipeline, pipeline_type): + if ( + self._skip_type_keys is not None + and transform_type in self._skip_type_keys + ): + continue + if transform_type == "MasaTransformBroadcaster": + for sub_transform in transform.transforms: + if hasattr(sub_transform, "get_indexes"): + sub_transform_type = type(sub_transform).__name__ + results = self.get_transform_indexes( + sub_transform, results, sample_video, sub_transform_type + ) + + elif hasattr(transform, "get_indexes"): + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. 
+ indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indexes + ] + if None not in mix_results: + results["mix_results"] = mix_results + break + else: + raise RuntimeError( + "The loading pipeline of the original dataset" + " always return None. Please check the correctness " + "of the dataset and its pipeline." + ) + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. + try: + updated_results = transform(copy.deepcopy(results)) + except Exception as e: + print( + "Error occurred while running pipeline", + f"{transform} with error: {e}", + ) + # print('Empty instances due to augmentation, re-sampling...') + # idx = self._rand_another(idx) + continue + if updated_results is not None: + results = updated_results + break + else: + raise RuntimeError( + "The training pipeline of the dataset wrapper" + " always return None.Please check the correctness " + "of the dataset and its pipeline." + ) + + if "mosaic_mix_results" in results: + results.pop("mosaic_mix_results") + + if "mixup_mix_results" in results: + results.pop("mixup_mix_results") + + if "copypaste_mix_results" in results: + results.pop("copypaste_mix_results") + + return results diff --git a/masa/datasets/evaluation/__init__.py b/masa/datasets/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9702f0533208a50e9670059472ccb8d3b30e87c9 --- /dev/null +++ b/masa/datasets/evaluation/__init__.py @@ -0,0 +1,4 @@ +from .bdd_teta_metric import BDDTETAMetric +from .tao_teta_metric import TaoTETAMetric + +__all__ = ["TaoTETAMetric", "BDDTETAMetric"] diff --git a/masa/datasets/evaluation/__pycache__/__init__.cpython-311.pyc b/masa/datasets/evaluation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16688069be751a53dca9e1b8f0239f023556dce6 Binary files /dev/null and b/masa/datasets/evaluation/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/datasets/evaluation/__pycache__/bdd_teta_metric.cpython-311.pyc b/masa/datasets/evaluation/__pycache__/bdd_teta_metric.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..288a99b4cd836b621c1797b3ca71c75229720bd6 Binary files /dev/null and b/masa/datasets/evaluation/__pycache__/bdd_teta_metric.cpython-311.pyc differ diff --git a/masa/datasets/evaluation/__pycache__/tao_teta_metric.cpython-311.pyc b/masa/datasets/evaluation/__pycache__/tao_teta_metric.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fc620f5394bad7952f69157c5dfeeb0d2a992da Binary files /dev/null and b/masa/datasets/evaluation/__pycache__/tao_teta_metric.cpython-311.pyc differ diff --git a/masa/datasets/evaluation/__pycache__/utils.cpython-311.pyc b/masa/datasets/evaluation/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cd6abf8abb30b370201baaf46a78dba72d843b6 Binary files /dev/null and b/masa/datasets/evaluation/__pycache__/utils.cpython-311.pyc differ diff --git a/masa/datasets/evaluation/bdd_teta_metric.py b/masa/datasets/evaluation/bdd_teta_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d68735f34f85c4892ceaeabce36d56d20d79345d --- /dev/null +++ b/masa/datasets/evaluation/bdd_teta_metric.py @@ -0,0 +1,614 @@ +import json +import os +import os.path as osp +import shutil 
+import tempfile +import time +from collections import defaultdict +from itertools import chain +from multiprocessing import Pool +from typing import List, Optional, Sequence, Union + +import pandas as pd +import torch +import tqdm + +try: + import teta +except ImportError: + teta = None +from pathlib import Path + +import mmengine +import mmengine.fileio as fileio +from mmdet.datasets.api_wrappers import COCO +from mmdet.evaluation.metrics.base_video_metric import BaseVideoMetric +from mmdet.registry import METRICS, TASK_UTILS +from mmengine.dist import (all_gather_object, barrier, broadcast, + broadcast_object_list, get_dist_info, + is_main_process) +from mmengine.logging import MMLogger +from scalabel.eval.box_track import BoxTrackResult, bdd100k_to_scalabel +from scalabel.eval.hota import HOTAResult, evaluate_track_hota +from scalabel.eval.hotas import evaluate_seg_track_hota +from scalabel.eval.mot import TrackResult, acc_single_video_mot, evaluate_track +from scalabel.eval.mots import acc_single_video_mots, evaluate_seg_track +from scalabel.eval.teta import TETAResult, evaluate_track_teta +from scalabel.eval.tetas import evaluate_seg_track_teta +from scalabel.label.io import group_and_sort, load, load_label_config + +from .utils import mask_postprocess, mask_prepare + +cpu_num = os.cpu_count() +NPROC: int = min(4, cpu_num if cpu_num else 1) + +MOT_CFG_FILE = os.path.join( + str(Path(__file__).parent.absolute()), "dataset_configs/box_track.toml" +) +MOTS_CFG_FILE = os.path.join( + str(Path(__file__).parent.absolute()), "dataset_configs/seg_track.toml" +) + + +def get_tmpdir() -> str: + """return the same tmpdir for all processes.""" + rank, world_size = get_dist_info() + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8) + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor(bytearray(tmpdir.encode()), dtype=torch.uint8) + dir_tensor[: len(tmpdir)] = tmpdir + broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + return tmpdir + + +@METRICS.register_module() +class BDDTETAMetric(BaseVideoMetric): + """Evaluation metrics for MOT Challenge. + + Args: + metric (str | list[str]): Metrics to be evaluated. Options are + 'HOTA', 'CLEAR', 'Identity'. + Defaults to ['HOTA', 'CLEAR', 'Identity']. + outfile_prefix (str, optional): Path to save the formatted results. + Defaults to None. + track_iou_thr (float): IoU threshold for tracking evaluation. + Defaults to 0.5. + benchmark (str): Benchmark to be evaluated. Defaults to 'MOT17'. + format_only (bool): If True, only formatting the results to the + official format and not performing evaluation. Defaults to False. + postprocess_tracklet_cfg (List[dict], optional): configs for tracklets + postprocessing methods. `InterpolateTracklets` is supported. + Defaults to [] + - InterpolateTracklets: + - min_num_frames (int, optional): The minimum length of a + track that will be interpolated. Defaults to 5. + - max_num_frames (int, optional): The maximum disconnected + length in a track. Defaults to 20. + - use_gsi (bool, optional): Whether to use the GSI (Gaussian- + smoothed interpolation) method. Defaults to False. + - smooth_tau (int, optional): smoothing parameter in GSI. + Defaults to 10. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. 
+ prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Default: None + Returns: + """ + + TRACKER = "masa-tracker" + allowed_metrics = ["TETA", "HOTA", "CLEAR"] + default_prefix: Optional[str] = "tao_teta_metric" + + def __init__( + self, + metric: Union[str, List[str]] = ["TETA", "HOTA", "CLEAR"], + outfile_prefix: Optional[str] = None, + track_iou_thr: float = 0.5, + format_only: bool = False, + ann_file: Optional[str] = None, + scalabel_gt: Optional[str] = None, + dataset_type: str = "BDDVideoDataset", + use_postprocess: bool = False, + postprocess_tracklet_cfg: Optional[List[dict]] = [], + collect_device: str = "cpu", + tcc: bool = True, + scalabel_format=True, + with_mask=False, + prefix: Optional[str] = None, + ) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + if teta is None: + raise RuntimeError( + "teta is not installed," + "please install it by: python -m pip install git+https://github.com/SysCV/tet.git/#subdirectory=teta " + ) + + if isinstance(metric, list): + metrics = metric + elif isinstance(metric, str): + metrics = [metric] + else: + raise TypeError("metric must be a list or a str.") + for metric in metrics: + if metric not in self.allowed_metrics: + raise KeyError(f"metric {metric} is not supported.") + self.metrics = metrics + self.format_only = format_only + self.scalabel_format = scalabel_format + if self.format_only: + assert outfile_prefix is not None, "outfile_prefix must be not" + "None when format_only is True, otherwise the result files will" + "be saved to a temp directory which will be cleaned up at the end." + self.use_postprocess = use_postprocess + self.postprocess_tracklet_cfg = postprocess_tracklet_cfg.copy() + self.postprocess_tracklet_methods = [ + TASK_UTILS.build(cfg) for cfg in self.postprocess_tracklet_cfg + ] + self.track_iou_thr = track_iou_thr + self.tmp_dir = tempfile.TemporaryDirectory() + self.tmp_dir.name = get_tmpdir() + self.seq_pred = defaultdict(lambda: []) + self.gt_dir = self._get_gt_dir() + self.pred_dir = self._get_pred_dir(outfile_prefix) + self.outfile_prefix = outfile_prefix + + self.ann_file = ann_file + self.scalabel_gt = scalabel_gt + self.tcc = tcc + self.with_mask = with_mask + + with fileio.get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + + # get the class list according to the dataset type + assert dataset_type in ["BDDVideoDataset"] + if dataset_type == "BDDVideoDataset": + from masa.datasets import BDDVideoDataset + + self.class_list = BDDVideoDataset.METAINFO["classes"] + + self.cat_ids = self.coco.get_cat_ids(cat_names=self.class_list) + + def __del__(self): + # To avoid tmpdir being cleaned up too early, because in multiple + # consecutive ValLoops, the value of `self.tmp_dir.name` is unchanged, + # and calling `tmp_dir.cleanup()` in compute_metrics will cause errors. 
+ self.tmp_dir.cleanup() + + def _get_pred_dir(self, outfile_prefix): + """Get directory to save the prediction results.""" + logger: MMLogger = MMLogger.get_current_instance() + + if outfile_prefix is None: + outfile_prefix = self.tmp_dir.name + else: + if osp.exists(outfile_prefix) and is_main_process(): + logger.info("remove previous results.") + shutil.rmtree(outfile_prefix) + pred_dir = osp.join(outfile_prefix, self.TRACKER) + os.makedirs(pred_dir, exist_ok=True) + return pred_dir + + def _get_gt_dir(self): + """Get directory to save the gt files.""" + output_dir = osp.join(self.tmp_dir.name, "gt") + os.makedirs(output_dir, exist_ok=True) + return output_dir + + def transform_gt_and_pred(self, img_data_sample): + + # load predictions + assert "pred_track_instances" in img_data_sample + pred_instances = img_data_sample["pred_track_instances"] + + pred_instances_list = [] + + for i in range(len(pred_instances["instances_id"])): + data_dict = dict() + data_dict["image_id"] = img_data_sample["img_id"] + data_dict["track_id"] = int(pred_instances["instances_id"][i]) + data_dict["bbox"] = self.xyxy2xywh(pred_instances["bboxes"][i]) + data_dict["score"] = float(pred_instances["scores"][i]) + data_dict["category_id"] = self.cat_ids[pred_instances["labels"][i]] + data_dict["video_id"] = img_data_sample["video_id"] + if self.with_mask: + if isinstance(pred_instances["masks"][i]["counts"], bytes): + pred_instances["masks"][i]["counts"] = pred_instances["masks"][i][ + "counts" + ].decode() + data_dict["segmentation"] = pred_instances["masks"][i] + pred_instances_list.append(data_dict) + + return pred_instances_list + + def process_image(self, data_samples, video_len): + + img_data_sample = data_samples[0].to_dict() + video_id = img_data_sample["video_id"] + pred_instances_list = self.transform_gt_and_pred(img_data_sample) + self.seq_pred[video_id].extend(pred_instances_list) + + def process_video(self, data_samples): + + video_len = len(data_samples) + for frame_id in range(video_len): + img_data_sample = data_samples[frame_id].to_dict() + # load basic info + video_id = img_data_sample["video_id"] + pred_instances_list = self.transform_gt_and_pred(img_data_sample) + self.seq_pred[video_id].extend(pred_instances_list) + + def compute_metrics(self, results: list = None) -> dict: + + logger: MMLogger = MMLogger.get_current_instance() + + eval_results = dict() + + if self.format_only: + logger.info("Only formatting results to the official format.") + return eval_results + + resfile_path = os.path.join( + self.outfile_prefix, "bdd_track_scalabel_format.json" + ) + + bdd100k_config = load_label_config(MOT_CFG_FILE) + print("Start loading.") + + gts = group_and_sort(load(self.scalabel_gt).frames) + results = group_and_sort(load(resfile_path).frames) + print("gt_len", len(gts), "results", len(results)) + print("Finish loading.") + print("Start evaluation") + print("Ignore unknown cats") + + logger.info("Tracking evaluation.") + t = time.time() + gts = [bdd100k_to_scalabel(gt, bdd100k_config) for gt in gts] + results = [bdd100k_to_scalabel(result, bdd100k_config) for result in results] + + if "CLEAR" in self.metrics: + if self.with_mask: + mot_result = evaluate_seg_track( + acc_single_video_mots, + gts, + results, + bdd100k_config, + ignore_unknown_cats=True, + nproc=NPROC, + ) + + else: + + mot_result = evaluate_track( + acc_single_video_mot, + gts, + results, + bdd100k_config, + ignore_unknown_cats=True, + nproc=NPROC, + ) + print("CLEAR and IDF1 results :") + print(mot_result) + 
print(mot_result.summary()) + + if "HOTA" in self.metrics: + if self.with_mask: + hota_result = evaluate_seg_track_hota( + gts, results, bdd100k_config, NPROC + ) + else: + hota_result = evaluate_track_hota(gts, results, bdd100k_config, NPROC) + print("HOTA results :") + print(hota_result) + print(hota_result.summary()) + + if "TETA" in self.metrics: + if self.with_mask: + teta_result = evaluate_seg_track_teta( + gts, results, bdd100k_config, NPROC + ) + else: + teta_result = evaluate_track_teta(gts, results, bdd100k_config, NPROC) + + print("TETA results :") + print(teta_result) + print(teta_result.summary()) + + if ( + "CLEAR" in self.metrics + and "HOTA" in self.metrics + and "TETA" in self.metrics + ): + print("Aggregated results: ") + combined_result = BoxTrackResult( + **{**mot_result.dict(), **hota_result.dict(), **teta_result.dict()} + ) + print(combined_result) + print(combined_result.summary()) + + t = time.time() - t + logger.info("evaluation finishes with %.1f s.", t) + + print("Completed evaluation") + return eval_results + + def evaluate(self, size: int = 1) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + Defaults to None. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + + """ + logger: MMLogger = MMLogger.get_current_instance() + + logger.info(f"Wait for all processes to complete prediction.") + # wait for all processes to complete prediction. + barrier() + + logger.info(f"Start gathering tracking results.") + + # gather seq_info and convert the list of dict to a dict. + # convert self.seq_info to dict first to make it picklable. + gathered_seq_info = all_gather_object(dict(self.seq_pred)) + + if is_main_process(): + + all_seq_pred = dict() + for _seq_info in gathered_seq_info: + all_seq_pred.update(_seq_info) + all_seq_pred = self.compute_global_track_id(all_seq_pred) + + # merge all the values (list of pred in each videos) into a single long list + all_seq_pred_json = list(chain.from_iterable(all_seq_pred.values())) + + if self.scalabel_format: + all_seq_pred_json = self.format_scalabel_pred(all_seq_pred_json) + + result_files_path = ( + f"{self.outfile_prefix}/bdd_track_scalabel_format.json" + ) + + logger.info(f"Saving json pred file into {result_files_path}") + mmengine.dump(all_seq_pred_json, result_files_path) + else: + if self.tcc and all_seq_pred_json: + all_seq_pred_json = self.majority_vote(all_seq_pred_json) + else: + all_seq_pred_json = all_seq_pred_json + + result_files_path = f"{self.outfile_prefix}/bdd_track_cocofmt.json" + + logger.info(f"Saving json pred file into {result_files_path}") + mmengine.dump(all_seq_pred_json, result_files_path) + + logger.info(f"Start evaluation") + + _metrics = self.compute_metrics() + + # Add prefix to metric names + if self.prefix: + _metrics = {"/".join((self.prefix, k)): v for k, v in _metrics.items()} + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + self.seq_pred.clear() + + return metrics[0] + + def format_scalabel_pred(self, all_seq_pred_json): + """Convert the prediction results to the format of Scalabel. + + Args: + all_seq_pred_json (list): The prediction results. + + Returns: + list: The formatted prediction results. 
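        Example of the track-class-correction (``tcc``) step this method applies via
        ``majority_vote`` (inputs trimmed to the relevant keys; values are illustrative):
            >>> preds = [dict(track_id=7, category_id=3),
            ...          dict(track_id=7, category_id=3),
            ...          dict(track_id=7, category_id=5)]
            >>> # majority_vote relabels the minority entry, so all three detections
            >>> # of track 7 end up with category_id == 3 before Scalabel conversion.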
+ """ + + bdd_scalabel_gt = json.load(open(self.ann_file)) + bdd_cid_cinfo_mapping = {} + for c in bdd_scalabel_gt["categories"]: + if c["id"] not in bdd_cid_cinfo_mapping: + bdd_cid_cinfo_mapping[c["id"]] = c + # imid info mapping + imid_iminfo_mapping = {} + for i in bdd_scalabel_gt["images"]: + if i["id"] not in imid_iminfo_mapping: + imid_iminfo_mapping[i["id"]] = i + # vidid info mapping + vid_vinfo_mapping = {} + for i in bdd_scalabel_gt["videos"]: + if i["id"] not in vid_vinfo_mapping: + vid_vinfo_mapping[i["id"]] = i + + if self.tcc and all_seq_pred_json: + mc_res = self.majority_vote(all_seq_pred_json) + else: + mc_res = all_seq_pred_json + + imid_results_mapping = self.convert_coco_result_to_bdd( + mc_res, bdd_cid_cinfo_mapping, imid_iminfo_mapping, vid_vinfo_mapping, + ) + + if self.with_mask: + scalabel_results = self.overlapping_masks_removal(imid_results_mapping) + else: + scalabel_results = list(imid_results_mapping.values()) + + return scalabel_results + + def overlapping_masks_removal(self, imid_results_mapping, nproc=NPROC): + + with Pool(nproc) as pool: + print("\nCollecting mask information") + mask_infors = pool.map( + mask_prepare, tqdm.tqdm(list(imid_results_mapping.values())) + ) + + print("\nRemoving overlaps and retrieving valid masks and indexes.") + results = pool.starmap(mask_postprocess, mask_infors) + + return results + + def compute_global_track_id(self, all_seq_pred): + + max_track_id = 0 + + for video_id, seq_pred in all_seq_pred.items(): + track_ids = [] + + for frame_pred in seq_pred: + track_ids.append(frame_pred["track_id"]) + frame_pred["track_id"] += max_track_id + track_ids = list(set(track_ids)) + + if track_ids: + max_track_id += max(track_ids) + 1 + + return all_seq_pred + + def majority_vote(self, prediction): + + tid_res_mapping = {} + for res in prediction: + tid = res["track_id"] + if tid not in tid_res_mapping: + tid_res_mapping[tid] = [res] + else: + tid_res_mapping[tid].append(res) + # change the results to data frame + df_pred_res = pd.DataFrame(prediction) + # group the results by track_id + # df_pred_res = df_pred_res.apply(changebbox, axis=1) + groued_df_pred_res = df_pred_res.groupby("track_id") + + # change the majority + class_by_majority_count_res = [] + for tid, group in tqdm.tqdm(groued_df_pred_res): + cid = group["category_id"].mode()[0] + group["category_id"] = cid + dict_list = group.to_dict("records") + class_by_majority_count_res += dict_list + return class_by_majority_count_res + + def xyxy2xywh(self, bbox): + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. + + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. + """ + + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def convert_pred_to_label_format(self, coco_pred, bdd_cid_cinfo_mapping): + """ + convert the single prediction result to label format for bdd + + coco_pred: + 'image_id': 1, + 'bbox': [998.872802734375, + 379.5665283203125, + 35.427490234375, + 59.21759033203125], + 'score': 0.9133418202400208, + 'category_id': 1, + 'video_id': 1, + 'track_id': 16 + + - labels [ ]: list of dicts + - id: string + - category: string + - box2d: + - x1: float + - y1: float + - x2: float + - y2: float + Args: + coco_pred: coco_pred dict. + bdd_cid_cinfo_mapping: bdd category id to category infomation mapping. + Return: + a new dict in bdd format. 
+ """ + new_label = {} + new_label["id"] = coco_pred["track_id"] + new_label["score"] = coco_pred["score"] + new_label["category"] = bdd_cid_cinfo_mapping[coco_pred["category_id"]]["name"] + new_label["box2d"] = { + "x1": coco_pred["bbox"][0], + "y1": coco_pred["bbox"][1], + "x2": coco_pred["bbox"][0] + coco_pred["bbox"][2], + "y2": coco_pred["bbox"][1] + coco_pred["bbox"][3], + } + if "segmentation" in coco_pred: + new_label["rle"] = coco_pred["segmentation"] + + return new_label + + def convert_coco_result_to_bdd( + self, new_pred, bdd_cid_cinfo_mapping, imid_iminfo_mapping, vid_vinfo_mapping + ): + """ + Args: + new_pred: list of coco predictions + bdd_cid_cinfo_mapping: bdd category id to category infomation mapping. + Return: + submitable result for bdd eval + """ + + imid_new_dict_mapping = {} + for item in tqdm.tqdm(new_pred): + imid = item["image_id"] + if imid not in imid_new_dict_mapping: + new_dict = {} + new_dict["name"] = imid_iminfo_mapping[imid]["file_name"] + new_dict["videoName"] = vid_vinfo_mapping[ + imid_iminfo_mapping[imid]["video_id"] + ]["name"] + new_dict["frameIndex"] = imid_iminfo_mapping[imid]["frame_id"] + new_dict["labels"] = [ + self.convert_pred_to_label_format(item, bdd_cid_cinfo_mapping) + ] + imid_new_dict_mapping[imid] = new_dict + else: + imid_new_dict_mapping[imid]["labels"].append( + self.convert_pred_to_label_format(item, bdd_cid_cinfo_mapping) + ) + for key in imid_iminfo_mapping: + if key not in imid_new_dict_mapping: + new_dict = {} + new_dict["name"] = imid_iminfo_mapping[key]["file_name"] + new_dict["videoName"] = vid_vinfo_mapping[ + imid_iminfo_mapping[key]["video_id"] + ]["name"] + new_dict["frameIndex"] = imid_iminfo_mapping[key]["frame_id"] + new_dict["labels"] = [] + imid_new_dict_mapping[key] = new_dict + + return imid_new_dict_mapping diff --git a/masa/datasets/evaluation/dataset_configs/box_track.toml b/masa/datasets/evaluation/dataset_configs/box_track.toml new file mode 100644 index 0000000000000000000000000000000000000000..2edcb999c4dae833e2c639ddde2995cb1133976f --- /dev/null +++ b/masa/datasets/evaluation/dataset_configs/box_track.toml @@ -0,0 +1,38 @@ +[image_size] +height = 720 +width = 1280 + +[[attributes]] +name = "crowd" +type = "switch" +tag = "c" + +[[categories]] +name = "human" + [[categories.subcategories]] + name = "pedestrian" + + [[categories.subcategories]] + name = "rider" + +[[categories]] +name = "vehicle" + [[categories.subcategories]] + name = "car" + + [[categories.subcategories]] + name = "truck" + + [[categories.subcategories]] + name = "bus" + + [[categories.subcategories]] + name = "train" + +[[categories]] +name = "bike" + [[categories.subcategories]] + name = "motorcycle" + + [[categories.subcategories]] + name = "bicycle" \ No newline at end of file diff --git a/masa/datasets/evaluation/dataset_configs/seg_track.toml b/masa/datasets/evaluation/dataset_configs/seg_track.toml new file mode 100644 index 0000000000000000000000000000000000000000..0af7b452239b574a8fb6b739c5812969aa1b69e3 --- /dev/null +++ b/masa/datasets/evaluation/dataset_configs/seg_track.toml @@ -0,0 +1,38 @@ +[imageSize] +height = 720 +width = 1280 + +[[attributes]] +name = "crowd" +type = "switch" +tag = "c" + +[[categories]] +name = "human" + [[categories.subcategories]] + name = "pedestrian" + + [[categories.subcategories]] + name = "rider" + +[[categories]] +name = "vehicle" + [[categories.subcategories]] + name = "car" + + [[categories.subcategories]] + name = "truck" + + [[categories.subcategories]] + name = "bus" + + 
[[categories.subcategories]] + name = "train" + +[[categories]] +name = "bike" + [[categories.subcategories]] + name = "motorcycle" + + [[categories.subcategories]] + name = "bicycle" diff --git a/masa/datasets/evaluation/tao_teta_metric.py b/masa/datasets/evaluation/tao_teta_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..f594deb9b498bd880e860994932632074d928487 --- /dev/null +++ b/masa/datasets/evaluation/tao_teta_metric.py @@ -0,0 +1,449 @@ +import os +import os.path as osp +import pickle +import shutil +import tempfile +from collections import defaultdict +from itertools import chain +from typing import List, Optional, Sequence, Union + +import numpy as np +import pandas as pd +import torch +import tqdm + +try: + import teta +except ImportError: + teta = None + +import mmengine +import mmengine.fileio as fileio +from mmdet.datasets.api_wrappers import COCO +from mmdet.evaluation.metrics.base_video_metric import BaseVideoMetric +from mmdet.registry import METRICS, TASK_UTILS +from mmengine.dist import (all_gather_object, barrier, broadcast, + broadcast_object_list, get_dist_info, + is_main_process) +from mmengine.logging import MMLogger + + +def get_tmpdir() -> str: + """return the same tmpdir for all processes.""" + rank, world_size = get_dist_info() + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8) + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor(bytearray(tmpdir.encode()), dtype=torch.uint8) + dir_tensor[: len(tmpdir)] = tmpdir + broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + return tmpdir + + +@METRICS.register_module() +class TaoTETAMetric(BaseVideoMetric): + """Evaluation metrics for TAO TETA and open-vocabulary MOT benchmark. + + Args: + metric (str | list[str]): Metrics to be evaluated. Options are + 'TETA' + Defaults to ['TETA']. + outfile_prefix (str, optional): Path to save the formatted results. + Defaults to None. + track_iou_thr (float): IoU threshold for tracking evaluation. + Defaults to 0.5. + benchmark (str): Benchmark to be evaluated. Defaults to 'MOT17'. + format_only (bool): If True, only formatting the results to the + official format and not performing evaluation. Defaults to False. + postprocess_tracklet_cfg (List[dict], optional): configs for tracklets + postprocessing methods. `InterpolateTracklets` is supported. + Defaults to [] + - InterpolateTracklets: + - min_num_frames (int, optional): The minimum length of a + track that will be interpolated. Defaults to 5. + - max_num_frames (int, optional): The maximum disconnected + length in a track. Defaults to 20. + - use_gsi (bool, optional): Whether to use the GSI (Gaussian- + smoothed interpolation) method. Defaults to False. + - smooth_tau (int, optional): smoothing parameter in GSI. + Defaults to 10. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. 
Default: None + Returns: + """ + + TRACKER = "masa-tracker" + allowed_metrics = ["TETA"] + default_prefix: Optional[str] = "tao_teta_metric" + + def __init__( + self, + metric: Union[str, List[str]] = ["TETA"], + outfile_prefix: Optional[str] = None, + track_iou_thr: float = 0.5, + format_only: bool = False, + ann_file: Optional[str] = None, + dataset_type: str = "Taov1Dataset", + use_postprocess: bool = False, + postprocess_tracklet_cfg: Optional[List[dict]] = [], + collect_device: str = "cpu", + tcc: bool = True, + open_vocabulary=False, + prefix: Optional[str] = None, + ) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + if teta is None: + raise RuntimeError( + "teta is not installed," + "please install it by: python -m pip install git+https://github.com/SysCV/tet.git/#subdirectory=teta " + ) + + if isinstance(metric, list): + metrics = metric + elif isinstance(metric, str): + metrics = [metric] + else: + raise TypeError("metric must be a list or a str.") + for metric in metrics: + if metric not in self.allowed_metrics: + raise KeyError(f"metric {metric} is not supported.") + self.metrics = metrics + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, "outfile_prefix must be not" + "None when format_only is True, otherwise the result files will" + "be saved to a temp directory which will be cleaned up at the end." + self.use_postprocess = use_postprocess + self.postprocess_tracklet_cfg = postprocess_tracklet_cfg.copy() + self.postprocess_tracklet_methods = [ + TASK_UTILS.build(cfg) for cfg in self.postprocess_tracklet_cfg + ] + self.track_iou_thr = track_iou_thr + self.tmp_dir = tempfile.TemporaryDirectory() + self.tmp_dir.name = get_tmpdir() + self.seq_pred = defaultdict(lambda: []) + self.gt_dir = self._get_gt_dir() + self.pred_dir = self._get_pred_dir(outfile_prefix) + self.outfile_prefix = outfile_prefix + + self.ann_file = ann_file + self.tcc = tcc + self.open_vocabulary = open_vocabulary + + with fileio.get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + + # get the class list according to the dataset type + assert dataset_type in ["Taov05Dataset", "Taov1Dataset"] + if dataset_type == "Taov05Dataset": + from masa.datasets import Taov05Dataset + + self.class_list = Taov05Dataset.METAINFO["classes"] + if dataset_type == "Taov1Dataset": + from masa.datasets import Taov1Dataset + + self.class_list = Taov1Dataset.METAINFO["classes"] + self.cat_ids = self.coco.get_cat_ids(cat_names=self.class_list) + + def __del__(self): + # To avoid tmpdir being cleaned up too early, because in multiple + # consecutive ValLoops, the value of `self.tmp_dir.name` is unchanged, + # and calling `tmp_dir.cleanup()` in compute_metrics will cause errors. 
+ self.tmp_dir.cleanup() + + def _get_pred_dir(self, outfile_prefix): + """Get directory to save the prediction results.""" + logger: MMLogger = MMLogger.get_current_instance() + + if outfile_prefix is None: + outfile_prefix = self.tmp_dir.name + else: + if osp.exists(outfile_prefix) and is_main_process(): + logger.info("remove previous results.") + shutil.rmtree(outfile_prefix) + pred_dir = osp.join(outfile_prefix, self.TRACKER) + os.makedirs(pred_dir, exist_ok=True) + return pred_dir + + def _get_gt_dir(self): + """Get directory to save the gt files.""" + output_dir = osp.join(self.tmp_dir.name, "gt") + os.makedirs(output_dir, exist_ok=True) + return output_dir + + def transform_gt_and_pred(self, img_data_sample): + + # load predictions + assert "pred_track_instances" in img_data_sample + pred_instances = img_data_sample["pred_track_instances"] + + pred_instances_list = [] + + for i in range(len(pred_instances["instances_id"])): + data_dict = dict() + data_dict["image_id"] = img_data_sample["img_id"] + data_dict["track_id"] = int(pred_instances["instances_id"][i]) + data_dict["bbox"] = self.xyxy2xywh(pred_instances["bboxes"][i]) + data_dict["score"] = float(pred_instances["scores"][i]) + data_dict["category_id"] = self.cat_ids[pred_instances["labels"][i]] + data_dict["video_id"] = img_data_sample["video_id"] + pred_instances_list.append(data_dict) + + return pred_instances_list + + def process_image(self, data_samples, video_len): + + img_data_sample = data_samples[0].to_dict() + video_id = img_data_sample["video_id"] + pred_instances_list = self.transform_gt_and_pred(img_data_sample) + self.seq_pred[video_id].extend(pred_instances_list) + + def process_video(self, data_samples): + + video_len = len(data_samples) + for frame_id in range(video_len): + img_data_sample = data_samples[frame_id].to_dict() + # load basic info + video_id = img_data_sample["video_id"] + pred_instances_list = self.transform_gt_and_pred(img_data_sample) + self.seq_pred[video_id].extend(pred_instances_list) + + def compute_metrics(self, results: list = None) -> dict: + + logger: MMLogger = MMLogger.get_current_instance() + + eval_results = dict() + + if self.format_only: + logger.info("Only formatting results to the official format.") + return eval_results + + resfile_path = self.outfile_prefix + + # Command line interface: + default_eval_config = teta.config.get_default_eval_config() + # print only combined since TrackMAP is undefined for per sequence breakdowns + default_eval_config["PRINT_ONLY_COMBINED"] = True + default_eval_config["DISPLAY_LESS_PROGRESS"] = True + default_eval_config["OUTPUT_TEM_RAW_DATA"] = True + default_eval_config["NUM_PARALLEL_CORES"] = 8 + default_dataset_config = teta.config.get_default_dataset_config() + default_dataset_config["TRACKERS_TO_EVAL"] = ["MASA"] + default_dataset_config["GT_FOLDER"] = self.ann_file + default_dataset_config["OUTPUT_FOLDER"] = resfile_path + default_dataset_config["TRACKER_SUB_FOLDER"] = os.path.join( + resfile_path, "tao_track.json" + ) + + evaluator = teta.Evaluator(default_eval_config) + dataset_list = [teta.datasets.TAO(default_dataset_config)] + print("Overall classes performance") + eval_results, _ = evaluator.evaluate(dataset_list, [teta.metrics.TETA()]) + + if self.open_vocabulary: + eval_results_path = os.path.join( + resfile_path, "MASA", "teta_summary_results.pth" + ) + eval_res = pickle.load(open(eval_results_path, "rb")) + + base_class_synset = set( + [ + c["name"] + for c in self.coco.dataset["categories"] + if c["frequency"] != "r" + ] + ) + 
novel_class_synset = set( + [ + c["name"] + for c in self.coco.dataset["categories"] + if c["frequency"] == "r" + ] + ) + + self.compute_teta_on_ovsetup( + eval_res, base_class_synset, novel_class_synset + ) + + return eval_results + + def evaluate(self, size: int = 1) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + Defaults to None. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + + """ + logger: MMLogger = MMLogger.get_current_instance() + + logger.info(f"Wait for all processes to complete prediction.") + # wait for all processes to complete prediction. + barrier() + + logger.info(f"Start gathering tracking results.") + + # gather seq_info and convert the list of dict to a dict. + # convert self.seq_info to dict first to make it picklable. + gathered_seq_info = all_gather_object(dict(self.seq_pred)) + + if is_main_process(): + + all_seq_pred = dict() + for _seq_info in gathered_seq_info: + all_seq_pred.update(_seq_info) + all_seq_pred = self.compute_global_track_id(all_seq_pred) + + # merge all the values (list of pred in each videos) into a single long list + all_seq_pred_json = list(chain.from_iterable(all_seq_pred.values())) + + if self.tcc and all_seq_pred_json: + all_seq_pred_json = self.majority_vote(all_seq_pred_json) + + result_files_path = f"{self.outfile_prefix}/tao_track.json" + + logger.info(f"Saving json pred file into {result_files_path}") + mmengine.dump(all_seq_pred_json, result_files_path) + + logger.info(f"Start evaluation") + + _metrics = self.compute_metrics() + + # Add prefix to metric names + if self.prefix: + _metrics = {"/".join((self.prefix, k)): v for k, v in _metrics.items()} + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + self.seq_pred.clear() + + return metrics[0] + + def compute_global_track_id(self, all_seq_pred): + + max_track_id = 0 + + for video_id, seq_pred in all_seq_pred.items(): + track_ids = [] + + for frame_pred in seq_pred: + track_ids.append(frame_pred["track_id"]) + frame_pred["track_id"] += max_track_id + track_ids = list(set(track_ids)) + + if track_ids: + max_track_id += max(track_ids) + 1 + + return all_seq_pred + + def majority_vote(self, prediction): + + tid_res_mapping = {} + for res in prediction: + tid = res["track_id"] + if tid not in tid_res_mapping: + tid_res_mapping[tid] = [res] + else: + tid_res_mapping[tid].append(res) + # change the results to data frame + df_pred_res = pd.DataFrame(prediction) + # group the results by track_id + # df_pred_res = df_pred_res.apply(changebbox, axis=1) + groued_df_pred_res = df_pred_res.groupby("track_id") + + # change the majority + class_by_majority_count_res = [] + for tid, group in tqdm.tqdm(groued_df_pred_res): + cid = group["category_id"].mode()[0] + group["category_id"] = cid + dict_list = group.to_dict("records") + class_by_majority_count_res += dict_list + return class_by_majority_count_res + + def xyxy2xywh(self, bbox): + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. + + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. 
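+
+        Example:
+            ``[10., 20., 30., 60.]`` in ``xyxy`` order becomes
+            ``[10., 20., 20., 40.]`` in ``xywh`` order.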
+ """ + + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def compute_teta_on_ovsetup(self, teta_res, base_class_names, novel_class_names): + if "COMBINED_SEQ" in teta_res: + teta_res = teta_res["COMBINED_SEQ"] + + frequent_teta = [] + rare_teta = [] + for key in teta_res: + if key in base_class_names: + frequent_teta.append(np.array(teta_res[key]["TETA"][50]).astype(float)) + elif key in novel_class_names: + rare_teta.append(np.array(teta_res[key]["TETA"][50]).astype(float)) + + print("Base and Novel classes performance") + + # print the header + print( + "{:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}".format( + "TETA50:", + "TETA", + "LocA", + "AssocA", + "ClsA", + "LocRe", + "LocPr", + "AssocRe", + "AssocPr", + "ClsRe", + "ClsPr", + ) + ) + + if frequent_teta: + freq_teta_mean = np.mean(np.stack(frequent_teta), axis=0) + + # print the frequent teta mean + print("{:<10} ".format("Base"), end="") + print(*["{:<10.3f}".format(num) for num in freq_teta_mean]) + + else: + print("No Base classes to evaluate!") + freq_teta_mean = None + if rare_teta: + rare_teta_mean = np.mean(np.stack(rare_teta), axis=0) + + # print the rare teta mean + print("{:<10} ".format("Novel"), end="") + print(*["{:<10.3f}".format(num) for num in rare_teta_mean]) + else: + print("No Novel classes to evaluate!") + rare_teta_mean = None + + return freq_teta_mean, rare_teta_mean diff --git a/masa/datasets/evaluation/utils.py b/masa/datasets/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd0a72335c94a80d222f5bf88572accd1b1636a --- /dev/null +++ b/masa/datasets/evaluation/utils.py @@ -0,0 +1,51 @@ +import numpy as np +from pycocotools import mask as mask_utils + +SHAPE = [720, 1280] + + +def mask_prepare(track_dict): + scores, masks = [], [] + labels = track_dict["labels"] + for instance in labels: + masks.append(mask_utils.decode(instance["rle"])) + scores.append(instance["score"]) + return scores, masks, track_dict + + +def mask_postprocess(scores, masks, track_dict): + sorted_idxs = np.argsort(scores)[::-1] # Sort indices in descending order of scores + processed_area = np.zeros( + SHAPE, dtype=np.uint8 + ) # Empty mask to record processed areas + + for idx in sorted_idxs: + current_mask = masks[idx] + # Remove overlapping parts with already processed areas + current_mask = np.where(processed_area, 0, current_mask) + if current_mask.sum() > 0: # Only keep non-empty masks + # Update processed area + processed_area = np.maximum(processed_area, current_mask) + + masks[idx] = current_mask + + valid_rle_masks = [ + mask_utils.encode(np.asfortranarray(masks[idx])) + for idx in sorted_idxs + if masks[idx].sum() > 0 + ] + valid_idxs = [idx for idx in sorted_idxs if masks[idx].sum() > 0] + + valid_track_dicts = track_dict.copy() + + valid_labels = [] + for i in range(len(valid_idxs)): + vidx = valid_idxs[i] + if isinstance(valid_rle_masks[i]["counts"], bytes): + valid_rle_masks[i]["counts"] = valid_rle_masks[i]["counts"].decode() + valid_track_dicts["labels"][vidx]["rle"] = valid_rle_masks[i] + valid_labels.append(valid_track_dicts["labels"][vidx]) + + valid_track_dicts["labels"] = valid_labels + + return valid_track_dicts diff --git a/masa/datasets/masa_dataset.py b/masa/datasets/masa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc22b60490933e40d98cbbb96a616b13d2aacd1 --- /dev/null +++ b/masa/datasets/masa_dataset.py @@ -0,0 +1,275 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +import copy +import logging +import os.path as osp +import pickle +from typing import List, Union + +import h5py +import tqdm +from mmdet.datasets.api_wrappers import COCO +from mmdet.datasets.base_det_dataset import BaseDetDataset +from mmdet.registry import DATASETS +from mmengine.fileio import get_local_path +from mmengine.logging import print_log + + +@DATASETS.register_module() +class MASADataset(BaseDetDataset): + """Dataset for COCO.""" + + METAINFO = { + "classes": ("object"), + # palette is a list of color tuples, which is used for visualization. + "palette": [(220, 20, 60)], + } + COCOAPI = COCO + # ann_id is unique in coco dataset. + ANN_ID_UNIQUE = True + + def __init__(self, anno_hdf5_path=None, img_prefix=None, *args, **kwargs): + + self.anno_hdf5_path = anno_hdf5_path + self.img_prefix = img_prefix + super().__init__(*args, **kwargs) + + def read_dicts_from_hdf5(self, hdf5_file_path, pkl_file_path): + with h5py.File(hdf5_file_path, "r") as hf: + # Retrieve the dataset corresponding to the specified .pkl file path + dataset = hf[pkl_file_path] + binary_data = dataset[()] + # Deserialize the binary data and load the list of dictionaries + list_of_dicts = pickle.loads(binary_data) + return list_of_dicts + + def get_ann_info(self, img_info): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + if self.anno_hdf5_path is not None: + try: + ann_info = self.read_dicts_from_hdf5( + self.anno_hdf5_path, img_info["file_name"].replace(".jpg", ".pkl") + ) + return ann_info + except: + print(self.anno_hdf5_path) + print(img_info["file_name"].replace(".jpg", ".pkl")) + return None + else: + img_id = img_info["id"] + ann_ids = self.coco.get_ann_ids(img_ids=[img_id], cat_ids=self.cat_ids) + ann_info = self.coco.load_anns(ann_ids) + return ann_info + + def __getitem__(self, idx: int) -> dict: + """Get the idx-th image and data information of dataset after + ``self.pipeline``, and ``full_init`` will be called if the dataset has + not been fully initialized. + + During training phase, if ``self.pipeline`` get ``None``, + ``self._rand_another`` will be called until a valid image is fetched or + the maximum limit of refetech is reached. + + Args: + idx (int): The index of self.data_list. + + Returns: + dict: The idx-th image and data information of dataset after + ``self.pipeline``. + """ + # Performing full initialization by calling `__getitem__` will consume + # extra memory. If a dataset is not fully initialized by setting + # `lazy_init=True` and then fed into the dataloader. Different workers + # will simultaneously read and parse the annotation. It will cost more + # time and memory, although this may work. Therefore, it is recommended + # to manually call `full_init` before dataset fed into dataloader to + # ensure all workers use shared RAM from master process. 
+ if not self._fully_initialized: + print_log( + "Please call `full_init()` method manually to accelerate " "the speed.", + logger="current", + level=logging.WARNING, + ) + self.full_init() + + if self.test_mode: + data = self.prepare_data(idx) + if data is None: + raise Exception( + "Test time pipline should not get `None` " "data_sample" + ) + return data + + for _ in range(self.max_refetch + 1): + try: + data = self.prepare_data(idx) + except Exception as e: + data = None + # Broken images or random augmentations may cause the returned data + # to be None + if data is None: + idx = self._rand_another() + continue + return data + + raise Exception( + f"Cannot find valid image after {self.max_refetch}! " + "Please check your image path and pipeline" + ) + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args + ) as local_path: + self.coco = self.COCOAPI(local_path) + # The order of returned `cat_ids` will not + # change with the order of the `classes` + self.cat_ids = self.coco.get_cat_ids(cat_names=self.metainfo["classes"]) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + print("Loading data list...") + for img_id in tqdm.tqdm(img_ids): + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info["img_id"] = img_id + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info( + {"raw_ann_info": raw_ann_info, "raw_img_info": raw_img_info} + ) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list + + def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information load from ``ann_file`` + + Returns: + Union[dict, List[dict]]: Parsed annotation. 
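+
+        Example:
+            A parsed entry has roughly the following shape (illustrative
+            values; ``text`` and ``caption_prompt`` are only added when
+            ``return_classes`` is enabled):
+
+            .. code-block:: python
+
+                {
+                    'img_path': 'data/images/000001.jpg',
+                    'img_id': 1,
+                    'seg_map_path': None,
+                    'height': 720,
+                    'width': 1280,
+                    'instances': [{'ignore_flag': 0,
+                                   'bbox': [x1, y1, x2, y2],
+                                   'bbox_label': 0,
+                                   'instance_id': 100}],
+                }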
+ """ + img_info = raw_data_info["raw_img_info"] + ann_info = raw_data_info["raw_ann_info"] + + data_info = {} + + # TODO: need to change data_prefix['img'] to data_prefix['img_path'] + img_path = osp.join(self.data_prefix["img"], img_info["file_name"]) + if self.data_prefix.get("seg", None): + seg_map_path = osp.join( + self.data_prefix["seg"], + img_info["file_name"].rsplit(".", 1)[0] + self.seg_map_suffix, + ) + else: + seg_map_path = None + data_info["img_path"] = img_path + data_info["img_id"] = img_info["img_id"] + data_info["seg_map_path"] = seg_map_path + data_info["height"] = img_info["height"] + data_info["width"] = img_info["width"] + + if self.return_classes: + data_info["text"] = self.metainfo["classes"] + data_info["caption_prompt"] = self.caption_prompt + data_info["custom_entities"] = True + + instances = [] + for i, ann in enumerate(ann_info): + instance = {} + + if ann.get("ignore", False): + continue + x1, y1, w, h = ann["bbox"] + inter_w = max(0, min(x1 + w, img_info["width"]) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info["height"]) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann["area"] <= 0 or w < 1 or h < 1: + continue + if "category_id" not in ann: + ann["category_id"] = 1 + if ann["category_id"] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + if ann.get("iscrowd", False): + instance["ignore_flag"] = 1 + else: + instance["ignore_flag"] = 0 + instance["bbox"] = bbox + instance["bbox_label"] = self.cat2label[ann["category_id"]] + + if ann.get("segmentation", None): + instance["mask"] = ann["segmentation"] + + if "instance_id" in ann: + instance["instance_id"] = ann["instance_id"] + else: + instance["instance_id"] = ann["id"] + + instances.append(instance) + data_info["instances"] = instances + return data_info + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. 
+ """ + if self.test_mode: + return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get("filter_empty_gt", False) + min_size = self.filter_cfg.get("min_size", 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info["img_id"] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info["img_id"] + width = data_info["width"] + height = data_info["height"] + if filter_empty_gt and img_id not in ids_in_cat: + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/masa/datasets/pipelines/__init__.py b/masa/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f548a4e94e475eae08a8d5ca8fce9dd9d331ae0 --- /dev/null +++ b/masa/datasets/pipelines/__init__.py @@ -0,0 +1,15 @@ +from .formatting import PackMatchInputs +from .framesample import MixUniformRefFrameSample +from .transforms import SeqCopyPaste, SeqMixUp, SeqMosaic, SeqRandomAffine +from .wrappers import MasaTransformBroadcaster + +__all__ = [ + "MasaTransformBroadcaster", + "MixUniformRefFrameSample", + "PackMatchInputs", + "SeqMosaic", + "SeqMixUp", + "SeqCopyPaste", + "SeqRandomAffine", + "PackMatchInputs", +] diff --git a/masa/datasets/pipelines/__pycache__/__init__.cpython-311.pyc b/masa/datasets/pipelines/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab7593f47662a8008de4458e88fc48ade89be442 Binary files /dev/null and b/masa/datasets/pipelines/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/datasets/pipelines/__pycache__/formatting.cpython-311.pyc b/masa/datasets/pipelines/__pycache__/formatting.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18e8d777f2fee9a0423421e0f9f0ed9f81c77793 Binary files /dev/null and b/masa/datasets/pipelines/__pycache__/formatting.cpython-311.pyc differ diff --git a/masa/datasets/pipelines/__pycache__/framesample.cpython-311.pyc b/masa/datasets/pipelines/__pycache__/framesample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ea8241c53144f170f83491397b89e225190e38 Binary files /dev/null and b/masa/datasets/pipelines/__pycache__/framesample.cpython-311.pyc differ diff --git a/masa/datasets/pipelines/__pycache__/transforms.cpython-311.pyc b/masa/datasets/pipelines/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4992f54d5b5fb13d146781001147089210b77da9 Binary files /dev/null and b/masa/datasets/pipelines/__pycache__/transforms.cpython-311.pyc differ diff --git a/masa/datasets/pipelines/__pycache__/wrappers.cpython-311.pyc b/masa/datasets/pipelines/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da192d260dff6d67c481adda48e8a52f420d7bb8 Binary files /dev/null and b/masa/datasets/pipelines/__pycache__/wrappers.cpython-311.pyc differ diff --git a/masa/datasets/pipelines/formatting.py b/masa/datasets/pipelines/formatting.py new file mode 100644 index 
0000000000000000000000000000000000000000..629454660a0806cd9af9fc5e23cb8935dd0cfd70 --- /dev/null +++ b/masa/datasets/pipelines/formatting.py @@ -0,0 +1,176 @@ +from typing import Optional, Sequence + +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmdet.registry import TRANSFORMS +from mmdet.structures import DetDataSample, TrackDataSample +from mmdet.structures.bbox import BaseBoxes +from mmengine.structures import InstanceData + + +@TRANSFORMS.register_module(force=True) +class PackMatchInputs(BaseTransform): + """Pack the inputs data for the multi object tracking and video instance + segmentation. All the information of images are packed to ``inputs``. All + the information except images are packed to ``data_samples``. In order to + get the original annotaiton and meta info, we add `instances` key into meta + keys. + + Args: + meta_keys (Sequence[str]): Meta keys to be collected in + ``data_sample.metainfo``. Defaults to None. + default_meta_keys (tuple): Default meta keys. Defaults to ('img_id', + 'img_path', 'ori_shape', 'img_shape', 'scale_factor', + 'flip', 'flip_direction', 'frame_id', 'is_video_data', + 'video_id', 'video_length', 'instances'). + """ + + mapping_table = { + "gt_bboxes": "bboxes", + "gt_bboxes_labels": "labels", + "gt_masks": "masks", + "gt_instances_ids": "instances_ids", + } + + def __init__( + self, + meta_keys: Optional[dict] = None, + default_meta_keys: tuple = ( + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "flip", + "flip_direction", + "frame_id", + "video_id", + "video_length", + "ori_video_length", + "instances", + ), + ): + self.meta_keys = default_meta_keys + if meta_keys is not None: + if isinstance(meta_keys, str): + meta_keys = (meta_keys,) + else: + assert isinstance(meta_keys, tuple), "meta_keys must be str or tuple" + self.meta_keys += meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + Args: + results (dict): Result dict from the data pipeline. + Returns: + dict: + - 'inputs' (dict[Tensor]): The forward data of models. + - 'data_samples' (obj:`TrackDataSample`): The annotation info of + the samples. + """ + packed_results = dict() + packed_results["inputs"] = dict() + + # 1. Pack images + if "img" in results: + imgs = results["img"] + imgs = np.stack(imgs, axis=0) + # imgs = imgs.transpose(0, 3, 1, 2) + if not imgs.flags.c_contiguous: + imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2)) + imgs = to_tensor(imgs) + else: + imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous() + packed_results["inputs"] = imgs + + # 2. 
Pack InstanceData + if "gt_ignore_flags" in results: + gt_ignore_flags_list = results["gt_ignore_flags"] + valid_idx_list, ignore_idx_list = [], [] + for gt_ignore_flags in gt_ignore_flags_list: + valid_idx = np.where(gt_ignore_flags == 0)[0] + ignore_idx = np.where(gt_ignore_flags == 1)[0] + valid_idx_list.append(valid_idx) + ignore_idx_list.append(ignore_idx) + + assert "img_id" in results, "'img_id' must contained in the results " + "for counting the number of images" + + num_imgs = len(results["img_id"]) + instance_data_list = [InstanceData() for _ in range(num_imgs)] + ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)] + + for key in self.mapping_table.keys(): + if key not in results: + continue + if key == "gt_masks": + mapped_key = self.mapping_table[key] + gt_masks_list = results[key] + if "gt_ignore_flags" in results: + for i, gt_mask in enumerate(gt_masks_list): + valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i] + instance_data_list[i][mapped_key] = gt_mask[valid_idx] + ignore_instance_data_list[i][mapped_key] = gt_mask[ignore_idx] + + else: + for i, gt_mask in enumerate(gt_masks_list): + instance_data_list[i][mapped_key] = gt_mask + + elif isinstance(results[key][0], BaseBoxes): + mapped_key = self.mapping_table[key] + gt_bboxes_list = results[key] + if "gt_ignore_flags" in results: + for i, gt_bbox in enumerate(gt_bboxes_list): + gt_bbox = gt_bbox.tensor + valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i] + instance_data_list[i][mapped_key] = gt_bbox[valid_idx] + ignore_instance_data_list[i][mapped_key] = gt_bbox[ignore_idx] + + else: + anns_list = results[key] + if "gt_ignore_flags" in results: + for i, ann in enumerate(anns_list): + valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i] + instance_data_list[i][self.mapping_table[key]] = to_tensor( + ann[valid_idx] + ) + ignore_instance_data_list[i][ + self.mapping_table[key] + ] = to_tensor(ann[ignore_idx]) + else: + for i, ann in enumerate(anns_list): + instance_data_list[i][self.mapping_table[key]] = to_tensor(ann) + + det_data_samples_list = [] + for i in range(num_imgs): + det_data_sample = DetDataSample() + det_data_sample.gt_instances = instance_data_list[i] + det_data_sample.ignored_instances = ignore_instance_data_list[i] + det_data_samples_list.append(det_data_sample) + + # 3. 
Pack metainfo + for key in self.meta_keys: + if key not in results: + continue + img_metas_list = results[key] + for i, img_meta in enumerate(img_metas_list): + det_data_samples_list[i].set_metainfo({f"{key}": img_meta}) + + track_data_sample = TrackDataSample() + track_data_sample.video_data_samples = det_data_samples_list + if "key_frame_flags" in results: + key_frame_flags = np.asarray(results["key_frame_flags"]) + key_frames_inds = np.where(key_frame_flags)[0].tolist() + ref_frames_inds = np.where(~key_frame_flags)[0].tolist() + track_data_sample.set_metainfo(dict(key_frames_inds=key_frames_inds)) + track_data_sample.set_metainfo(dict(ref_frames_inds=ref_frames_inds)) + + packed_results["data_samples"] = track_data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"meta_keys={self.meta_keys}, " + repr_str += f"default_meta_keys={self.default_meta_keys})" + return repr_str diff --git a/masa/datasets/pipelines/framesample.py b/masa/datasets/pipelines/framesample.py new file mode 100644 index 0000000000000000000000000000000000000000..ea69e3e1a7e650ec2c94659ee0c9d2e7a2fabff5 --- /dev/null +++ b/masa/datasets/pipelines/framesample.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import random +from collections import defaultdict +from typing import Dict, List, Optional, Union + +from mmdet.datasets.transforms.frame_sampling import BaseFrameSample +from mmdet.registry import TRANSFORMS + + +@TRANSFORMS.register_module(force=True) +class MixUniformRefFrameSample(BaseFrameSample): + """Uniformly sample reference frames. + + Args: + num_ref_imgs (int): Number of reference frames to be sampled. + frame_range (int | list[int]): Range of frames to be sampled around + key frame. If int, the range is [-frame_range, frame_range]. + Defaults to 10. + filter_key_img (bool): Whether to filter the key frame when + sampling reference frames. Defaults to True. + collect_video_keys (list[str]): The keys of video info to be + collected. + """ + + def __init__( + self, + num_ref_imgs: int = 1, + frame_range: Union[int, List[int]] = 10, + filter_key_img: bool = True, + collect_video_keys: List[str] = ["video_id", "video_length"], + ): + self.num_ref_imgs = num_ref_imgs + self.filter_key_img = filter_key_img + if isinstance(frame_range, int): + assert frame_range >= 0, "frame_range can not be a negative value." + frame_range = [-frame_range, frame_range] + elif isinstance(frame_range, list): + assert len(frame_range) == 2, "The length must be 2." + assert frame_range[0] <= 0 and frame_range[1] >= 0 + for i in frame_range: + assert isinstance(i, int), "Each element must be int." + else: + raise TypeError("The type of frame_range must be int or list.") + self.frame_range = frame_range + super().__init__(collect_video_keys=collect_video_keys) + + def sampling_frames(self, video_length: int, key_frame_id: int): + """Sampling frames. + + Args: + video_length (int): The length of the video. + key_frame_id (int): The key frame id. + + Returns: + list[int]: The sampled frame indices. 
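+            list[bool]: Flags marking which of the sampled indices is the
+                key frame.
+
+        Example:
+            With ``frame_range=[-10, 10]``, ``num_ref_imgs=1``,
+            ``video_length=30`` and ``key_frame_id=5``, one reference frame
+            is drawn uniformly from indices ``0``-``15`` (excluding ``5``),
+            and the sorted pair of key and reference indices is returned
+            together with the key-frame flags.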
+ """ + if video_length > 1: + left = max(0, key_frame_id + self.frame_range[0]) + right = min(key_frame_id + self.frame_range[1], video_length - 1) + frame_ids = list(range(0, video_length)) + + valid_ids = frame_ids[left : right + 1] + if self.filter_key_img and key_frame_id in valid_ids: + valid_ids.remove(key_frame_id) + assert ( + len(valid_ids) > 0 + ), "After filtering key frame, there are no valid frames" + if len(valid_ids) < self.num_ref_imgs: + valid_ids = valid_ids * self.num_ref_imgs + ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs) + else: + ref_frame_ids = [key_frame_id] * self.num_ref_imgs + + sampled_frames_ids = [key_frame_id] + ref_frame_ids + sampled_frames_ids = sorted(sampled_frames_ids) + + key_frames_ind = sampled_frames_ids.index(key_frame_id) + key_frame_flags = [False] * len(sampled_frames_ids) + key_frame_flags[key_frames_ind] = True + return sampled_frames_ids, key_frame_flags + + def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: + """Transform the video information. + + Args: + video_infos (dict): The whole video information. + + Returns: + dict: The data information of the sampled frames. + """ + + if "video_length" not in video_infos: + generated_video_info = {} + key_frame_id = 0 + generated_video_info["video_id"] = video_infos["img_id"] + generated_video_info["video_length"] = 1 + generated_video_info["key_frame_id"] = key_frame_id + generated_video_info["images"] = [video_infos] + (sampled_frames_ids, key_frame_flags) = self.sampling_frames( + generated_video_info["video_length"], key_frame_id=key_frame_id + ) + results = self.prepare_data(generated_video_info, sampled_frames_ids) + results["key_frame_flags"] = key_frame_flags + # results['is_image'] = True + + else: + if "key_frame_id" in video_infos: + key_frame_id = video_infos["key_frame_id"] + assert isinstance(video_infos["key_frame_id"], int) + else: + key_frame_id = random.sample( + list(range(video_infos["video_length"])), 1 + )[0] + + (sampled_frames_ids, key_frame_flags) = self.sampling_frames( + video_infos["video_length"], key_frame_id=key_frame_id + ) + results = self.prepare_data(video_infos, sampled_frames_ids) + results["key_frame_flags"] = key_frame_flags + # results['is_image'] = False + + return results + + def prepare_data( + self, video_infos: dict, sampled_inds: List[int] + ) -> Dict[str, List]: + """Prepare data for the subsequent pipeline. + + Args: + video_infos (dict): The whole video information. + sampled_inds (list[int]): The sampled frame indices. + + Returns: + dict: The processed data information. 
+ """ + frames_anns = video_infos["images"] + final_data_info = defaultdict(list) + # for data in frames_anns: + for index in sampled_inds: + data = copy.deepcopy(frames_anns[index]) + # copy the info in video-level into img-level + for key in self.collect_video_keys: + if key == "video_length": + data["ori_video_length"] = video_infos[key] + data["video_length"] = len(sampled_inds) + else: + data[key] = video_infos[key] + # Collate data_list (list of dict to dict of list) + for key, value in data.items(): + final_data_info[key].append(value) + + return final_data_info + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(num_ref_imgs={self.num_ref_imgs}, " + repr_str += f"frame_range={self.frame_range}, " + repr_str += f"filter_key_img={self.filter_key_img}, " + repr_str += f"collect_video_keys={self.collect_video_keys})" + return repr_str diff --git a/masa/datasets/pipelines/loading.py b/masa/datasets/pipelines/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9ff16faaae5a01838d925d6ed64cfce35d17d6 --- /dev/null +++ b/masa/datasets/pipelines/loading.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import get_box_type + + +@TRANSFORMS.register_module() +class LoadMatchAnnotations(LoadAnnotations): + """Load and process the ``instances`` and ``seg_map`` annotation provided + by dataset. It must load ``instances_ids`` which is only used in the + tracking tasks. The annotation format is as the following: + + .. code-block:: python + { + 'instances': + [ + { + # List of 4 numbers representing the bounding box of the + # instance, in (x1, y1, x2, y2) order. + 'bbox': [x1, y1, x2, y2], + # Label of image classification. + 'bbox_label': 1, + # Used in tracking. + # Id of instances. + 'instance_id': 100, + # Used in instance/panoptic segmentation. The segmentation mask + # of the instance or the information of segments. + # 1. If list[list[float]], it represents a list of polygons, + # one for each connected component of the object. Each + # list[float] is one simple polygon in the format of + # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute + # coordinates in unit of pixels. + # 2. If dict, it represents the per-pixel segmentation mask in + # COCO's compressed RLE format. The dict should have keys + # “size” and “counts”. Can be loaded by pycocotools + 'mask': list[list[float]] or dict, + } + ] + # Filename of semantic or panoptic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + .. code-block:: python + { + # In (x1, y1, x2, y2) order, float type. N is the number of bboxes + # in an image + 'gt_bboxes': np.ndarray(N, 4) + # In int type. + 'gt_bboxes_labels': np.ndarray(N, ) + # In built-in class + 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W) + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + # in (x, y, v) order, float type. 
+ } + + Required Keys: + + - height (optional) + - width (optional) + - instances + - bbox (optional) + - bbox_label + - instance_id (optional) + - mask (optional) + - ignore_flag (optional) + - seg_map_path (optional) + + Added Keys: + + - gt_bboxes (np.float32) + - gt_bboxes_labels (np.int32) + - gt_instances_ids (np.int32) + - gt_masks (BitmapMasks | PolygonMasks) + - gt_seg_map (np.uint8) + - gt_ignore_flags (np.bool) + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def _load_bboxes(self, results: dict) -> None: + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + gt_bboxes = [] + gt_ignore_flags = [] + # TODO: use bbox_type + for instance in results["instances"]: + # The datasets which are only format in evaluation don't have + # groundtruth boxes. + if "bbox" in instance: + gt_bboxes.append(instance["bbox"]) + if "ignore_flag" in instance: + gt_ignore_flags.append(instance["ignore_flag"]) + + # TODO: check this case + if len(gt_bboxes) != len(gt_ignore_flags): + # There may be no ``gt_ignore_flags`` in some cases, we treat them + # as all False in order to keep the length of ``gt_bboxes`` and + # ``gt_ignore_flags`` the same + gt_ignore_flags = [False] * len(gt_bboxes) + + if self.box_type is None: + results["gt_bboxes"] = np.array(gt_bboxes, dtype=np.float32).reshape( + (-1, 4) + ) + else: + _, box_type_cls = get_box_type(self.box_type) + results["gt_bboxes"] = box_type_cls(gt_bboxes, dtype=torch.float32) + results["gt_ignore_flags"] = np.array(gt_ignore_flags, dtype=bool) + + def _load_instances_ids(self, results: dict) -> None: + """Private function to load instances id annotations. + + Args: + results (dict): Result dict from :obj :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict containing instances id annotations. + """ + gt_instances_ids = [] + for instance in results["instances"]: + gt_instances_ids.append(instance["instance_id"]) + results["gt_instances_ids"] = np.array(gt_instances_ids, dtype=np.int32) + + def transform(self, results: dict) -> dict: + """Function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded bounding box, label, instances id + and semantic segmentation and keypoints annotations. + """ + results = super().transform(results) + self._load_instances_ids(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(with_bbox={self.with_bbox}, " + repr_str += f"with_label={self.with_label}, " + repr_str += f"with_mask={self.with_mask}, " + repr_str += f"with_seg={self.with_seg}, " + repr_str += f"poly2mask={self.poly2mask}, " + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f"file_client_args={self.file_client_args})" + return repr_str diff --git a/masa/datasets/pipelines/transforms.py b/masa/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d797fc8ee9db748ee3a62f961221ba046713196e --- /dev/null +++ b/masa/datasets/pipelines/transforms.py @@ -0,0 +1,1176 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +import math +import warnings +from typing import List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import numpy as np +from mmcv.image import imresize +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import autocast_box_type +from mmdet.structures.mask import BitmapMasks +from mmdet.utils import log_img_scale +from mmengine.dataset import BaseDataset +from numpy import random + +try: + from imagecorruptions import corrupt +except ImportError: + corrupt = None + +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + +Number = Union[int, float] + + +def _fixed_scale_size( + size: Tuple[int, int], scale: Union[float, int, tuple], +) -> Tuple[int, int]: + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + # don't need o.5 offset + return int(w * float(scale[0])), int(h * float(scale[1])) + + +def rescale_size( + old_size: tuple, scale: Union[float, int, tuple], return_scale: bool = False +) -> tuple: + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f"Invalid scale {scale}, must be positive.") + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) + else: + raise TypeError( + f"Scale must be a number or tuple of int, but got {type(scale)}" + ) + # only change this + new_size = _fixed_scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale( + img: np.ndarray, + scale: Union[float, Tuple[int, int]], + return_scale: bool = False, + interpolation: str = "bilinear", + backend: Optional[str] = None, +) -> Union[np.ndarray, Tuple[np.ndarray, float]]: + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. 
+ """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +@TRANSFORMS.register_module(force=True) +class SeqMosaic(BaseTransform): + """Mosaic augmentation. + + Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - mix_results (List[dict]) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + - gt_instances_ids (options, only used in MOT/VIS) + + Args: + img_scale (Sequence[int]): Image size before mosaic pipeline of single + image. The shape order should be (width, height). + Defaults to (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Defaults to (0.5, 1.5). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + pad_val (int): Pad value. Defaults to 114. + prob (float): Probability of applying this transformation. + Defaults to 1.0. + """ + + def __init__( + self, + img_scale: Tuple[int, int] = (640, 640), + center_ratio_range: Tuple[float, float] = (0.5, 1.5), + bbox_clip_border: bool = True, + pad_val: float = 114.0, + prob: float = 1.0, + ) -> None: + assert isinstance(img_scale, tuple) + assert 0 <= prob <= 1.0, ( + "The probability should be in range [0,1]. " f"got {prob}." + ) + + log_img_scale(img_scale, skip_square=True, shape_order="wh") + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.bbox_clip_border = bbox_clip_border + self.pad_val = pad_val + self.prob = prob + + @cache_randomness + def get_indexes(self, dataset: BaseDataset) -> int: + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + indexes = [random.randint(0, len(dataset)) for _ in range(3)] + return indexes + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
+ """ + if random.uniform(0, 1) > self.prob: + return results + + assert "mosaic_mix_results" in results + mosaic_bboxes = [] + mosaic_bboxes_labels = [] + mosaic_ignore_flags = [] + mosaic_instances_ids = [] + if len(results["img"].shape) == 3: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3), + self.pad_val, + dtype=results["img"].dtype, + ) + else: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)), + self.pad_val, + dtype=results["img"].dtype, + ) + + # mosaic center x, y + center_x = int(random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_y = int(random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_position = (center_x, center_y) + + loc_strs = ("top_left", "top_right", "bottom_left", "bottom_right") + for i, loc in enumerate(loc_strs): + if loc == "top_left": + results_patch = copy.deepcopy(results) + else: + results_patch = copy.deepcopy(results["mosaic_mix_results"][i - 1]) + + img_i = results_patch["img"] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[1] / h_i, self.img_scale[0] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)) + ) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1] + ) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + # adjust coordinate + gt_bboxes_i = results_patch["gt_bboxes"] + gt_bboxes_labels_i = results_patch["gt_bboxes_labels"] + gt_ignore_flags_i = results_patch["gt_ignore_flags"] + gt_instances_ids_i = results_patch.get("gt_instances_ids", None) + + padw = x1_p - x1_c + padh = y1_p - y1_c + gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i]) + gt_bboxes_i.translate_([padw, padh]) + mosaic_bboxes.append(gt_bboxes_i) + mosaic_bboxes_labels.append(gt_bboxes_labels_i) + mosaic_ignore_flags.append(gt_ignore_flags_i) + mosaic_instances_ids.append(gt_instances_ids_i) + + if len(mosaic_bboxes_labels) > 0: + mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0) + mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0) + mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0) + mosaic_instances_ids = np.concatenate(mosaic_instances_ids, 0) + + if self.bbox_clip_border: + mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]]) + # remove outside bboxes + inside_inds = mosaic_bboxes.is_inside( + [2 * self.img_scale[1], 2 * self.img_scale[0]] + ).numpy() + mosaic_bboxes = mosaic_bboxes[inside_inds] + mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds] + mosaic_ignore_flags = mosaic_ignore_flags[inside_inds] + mosaic_instances_ids = mosaic_instances_ids[inside_inds] + + results["img"] = mosaic_img + results["img_shape"] = mosaic_img.shape[:2] + results["gt_bboxes"] = mosaic_bboxes + results["gt_bboxes_labels"] = mosaic_bboxes_labels + results["gt_ignore_flags"] = mosaic_ignore_flags + results["gt_instances_ids"] = mosaic_instances_ids + + return results + + def _mosaic_combine( + self, loc: str, center_position_xy: Sequence[float], img_shape_wh: Sequence[int] + ) -> Tuple[Tuple[int], Tuple[int]]: + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). 
+ center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. + """ + assert loc in ("top_left", "top_right", "bottom_left", "bottom_right") + if loc == "top_left": + # index0 to top left part of image + x1, y1, x2, y2 = ( + max(center_position_xy[0] - img_shape_wh[0], 0), + max(center_position_xy[1] - img_shape_wh[1], 0), + center_position_xy[0], + center_position_xy[1], + ) + crop_coord = ( + img_shape_wh[0] - (x2 - x1), + img_shape_wh[1] - (y2 - y1), + img_shape_wh[0], + img_shape_wh[1], + ) + + elif loc == "top_right": + # index1 to top right part of image + x1, y1, x2, y2 = ( + center_position_xy[0], + max(center_position_xy[1] - img_shape_wh[1], 0), + min(center_position_xy[0] + img_shape_wh[0], self.img_scale[0] * 2), + center_position_xy[1], + ) + crop_coord = ( + 0, + img_shape_wh[1] - (y2 - y1), + min(img_shape_wh[0], x2 - x1), + img_shape_wh[1], + ) + + elif loc == "bottom_left": + # index2 to bottom left part of image + x1, y1, x2, y2 = ( + max(center_position_xy[0] - img_shape_wh[0], 0), + center_position_xy[1], + center_position_xy[0], + min(self.img_scale[1] * 2, center_position_xy[1] + img_shape_wh[1]), + ) + crop_coord = ( + img_shape_wh[0] - (x2 - x1), + 0, + img_shape_wh[0], + min(y2 - y1, img_shape_wh[1]), + ) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = ( + center_position_xy[0], + center_position_xy[1], + min(center_position_xy[0] + img_shape_wh[0], self.img_scale[0] * 2), + min(self.img_scale[1] * 2, center_position_xy[1] + img_shape_wh[1]), + ) + crop_coord = ( + 0, + 0, + min(img_shape_wh[0], x2 - x1), + min(y2 - y1, img_shape_wh[1]), + ) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(img_scale={self.img_scale}, " + repr_str += f"center_ratio_range={self.center_ratio_range}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"prob={self.prob})" + return repr_str + + +@TRANSFORMS.register_module(force=True) +class SeqMixUp(BaseTransform): + """MixUp data augmentation. + + .. code:: text + + mixup transform + +------------------------------+ + | mixup image | | + | +--------|--------+ | + | | | | | + |---------------+ | | + | | | | + | | image | | + | | | | + | | | | + | |-----------------+ | + | pad | + +------------------------------+ + + The mixup transform steps are as follows: + + 1. Another random image is picked by dataset and embedded in + the top left patch(after padding and resizing) + 2. The target of mixup transform is the weighted average of mixup + image and origin image. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - mix_results (List[dict]) + + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + + Args: + img_scale (Sequence[int]): Image output size after mixup pipeline. + The shape order should be (width, height). Defaults to (640, 640). + ratio_range (Sequence[float]): Scale ratio of mixup image. + Defaults to (0.5, 1.5). + flip_ratio (float): Horizontal flip ratio of mixup image. + Defaults to 0.5. + pad_val (int): Pad value. 
Defaults to 114. + max_iters (int): The maximum number of iterations. If the number of + iterations is greater than `max_iters`, but gt_bbox is still + empty, then the iteration is terminated. Defaults to 15. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + """ + + def __init__( + self, + img_scale: Tuple[int, int] = (640, 640), + ratio_range: Tuple[float, float] = (0.5, 1.5), + flip_ratio: float = 0.5, + pad_val: float = 114.0, + max_iters: int = 15, + bbox_clip_border: bool = True, + ) -> None: + assert isinstance(img_scale, tuple) + log_img_scale(img_scale, skip_square=True, shape_order="wh") + self.dynamic_scale = img_scale + self.ratio_range = ratio_range + self.flip_ratio = flip_ratio + self.pad_val = pad_val + self.max_iters = max_iters + self.bbox_clip_border = bbox_clip_border + + @cache_randomness + def get_indexes(self, dataset: BaseDataset) -> int: + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + for i in range(self.max_iters): + index = random.randint(0, len(dataset)) + gt_bboxes_i = dataset[index]["gt_bboxes"] + if len(gt_bboxes_i) != 0: + break + + return index + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """MixUp transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert "mixup_mix_results" in results + assert ( + len(results["mixup_mix_results"]) == 1 + ), "MixUp only support 2 images now !" + + if results["mixup_mix_results"][0]["gt_bboxes"].shape[0] == 0: + # empty bbox + return results + + retrieve_results = copy.deepcopy(results["mixup_mix_results"][0]) + retrieve_img = retrieve_results["img"] + + jit_factor = random.uniform(*self.ratio_range) + is_flip = random.uniform(0, 1) > self.flip_ratio + + if len(retrieve_img.shape) == 3: + out_img = ( + np.ones( + (self.dynamic_scale[1], self.dynamic_scale[0], 3), + dtype=retrieve_img.dtype, + ) + * self.pad_val + ) + else: + out_img = ( + np.ones(self.dynamic_scale[::-1], dtype=retrieve_img.dtype) + * self.pad_val + ) + + # 1. keep_ratio resize + scale_ratio = min( + self.dynamic_scale[1] / retrieve_img.shape[0], + self.dynamic_scale[0] / retrieve_img.shape[1], + ) + retrieve_img = mmcv.imresize( + retrieve_img, + ( + int(retrieve_img.shape[1] * scale_ratio), + int(retrieve_img.shape[0] * scale_ratio), + ), + ) + + # 2. paste + out_img[: retrieve_img.shape[0], : retrieve_img.shape[1]] = retrieve_img + + # 3. scale jit + scale_ratio *= jit_factor + out_img = mmcv.imresize( + out_img, + (int(out_img.shape[1] * jit_factor), int(out_img.shape[0] * jit_factor)), + ) + + # 4. flip + if is_flip: + out_img = out_img[:, ::-1, :] + + # 5. 
random crop + ori_img = results["img"] + origin_h, origin_w = out_img.shape[:2] + target_h, target_w = ori_img.shape[:2] + padded_img = ( + np.ones((max(origin_h, target_h), max(origin_w, target_w), 3)) + * self.pad_val + ) + padded_img = padded_img.astype(np.uint8) + padded_img[:origin_h, :origin_w] = out_img + + x_offset, y_offset = 0, 0 + if padded_img.shape[0] > target_h: + y_offset = random.randint(0, padded_img.shape[0] - target_h) + if padded_img.shape[1] > target_w: + x_offset = random.randint(0, padded_img.shape[1] - target_w) + padded_cropped_img = padded_img[ + y_offset : y_offset + target_h, x_offset : x_offset + target_w + ] + + # 6. adjust bbox + retrieve_gt_bboxes = retrieve_results["gt_bboxes"] + retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio]) + if self.bbox_clip_border: + retrieve_gt_bboxes.clip_([origin_h, origin_w]) + + if is_flip: + retrieve_gt_bboxes.flip_([origin_h, origin_w], direction="horizontal") + + # 7. filter + cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone() + cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset]) + if self.bbox_clip_border: + cp_retrieve_gt_bboxes.clip_([target_h, target_w]) + + # 8. mix up + ori_img = ori_img.astype(np.float32) + mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32) + + retrieve_gt_bboxes_labels = retrieve_results["gt_bboxes_labels"] + retrieve_gt_ignore_flags = retrieve_results["gt_ignore_flags"] + retrieve_gt_instances_ids = retrieve_results["gt_instances_ids"] + + mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat( + (results["gt_bboxes"], cp_retrieve_gt_bboxes), dim=0 + ) + mixup_gt_bboxes_labels = np.concatenate( + (results["gt_bboxes_labels"], retrieve_gt_bboxes_labels), axis=0 + ) + mixup_gt_ignore_flags = np.concatenate( + (results["gt_ignore_flags"], retrieve_gt_ignore_flags), axis=0 + ) + mixup_gt_instances_ids = np.concatenate( + (results["gt_instances_ids"], retrieve_gt_instances_ids), axis=0 + ) + + # remove outside bbox + inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy() + mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] + mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds] + mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds] + mixup_gt_instances_ids = mixup_gt_instances_ids[inside_inds] + + results["img"] = mixup_img.astype(np.uint8) + results["img_shape"] = mixup_img.shape[:2] + results["gt_bboxes"] = mixup_gt_bboxes + results["gt_bboxes_labels"] = mixup_gt_bboxes_labels + results["gt_ignore_flags"] = mixup_gt_ignore_flags + results["gt_instances_ids"] = mixup_gt_instances_ids + + assert len(results["gt_bboxes"]) == len(results["gt_instances_ids"]) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(dynamic_scale={self.dynamic_scale}, " + repr_str += f"ratio_range={self.ratio_range}, " + repr_str += f"flip_ratio={self.flip_ratio}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"max_iters={self.max_iters}, " + repr_str += f"bbox_clip_border={self.bbox_clip_border})" + return repr_str + + +@TRANSFORMS.register_module(force=True) +class FilterMatchAnnotations(BaseTransform): + """Filter invalid annotations. 
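+
+    Besides boxes, labels and masks, the tracking ``gt_instances_ids`` (when
+    present in the results) are filtered with the same keep mask, so boxes and
+    instance ids stay aligned.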
+ + Required Keys: + + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_masks (optional) + - gt_ignore_flags (optional) + + Args: + min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth + boxes. Default: (1., 1.) + min_gt_mask_area (int): Minimum foreground area of ground truth masks. + Default: 1 + by_box (bool): Filter instances with bounding boxes not meeting the + min_gt_bbox_wh threshold. Default: True + by_mask (bool): Filter instances with masks not meeting + min_gt_mask_area threshold. Default: False + keep_empty (bool): Whether to return None when it + becomes an empty bbox after filtering. Defaults to True. + """ + + def __init__( + self, + min_gt_bbox_wh: Tuple[int, int] = (1, 1), + min_gt_mask_area: int = 1, + by_box: bool = True, + by_mask: bool = False, + keep_empty: bool = True, + ) -> None: + # TODO: add more filter options + assert by_box or by_mask + self.min_gt_bbox_wh = min_gt_bbox_wh + self.min_gt_mask_area = min_gt_mask_area + self.by_box = by_box + self.by_mask = by_mask + self.keep_empty = keep_empty + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + assert "gt_bboxes" in results + gt_bboxes = results["gt_bboxes"] + if gt_bboxes.shape[0] == 0: + return results + + tests = [] + if self.by_box: + tests.append( + ( + (gt_bboxes.widths > self.min_gt_bbox_wh[0]) + & (gt_bboxes.heights > self.min_gt_bbox_wh[1]) + ).numpy() + ) + if self.by_mask: + assert "gt_masks" in results + gt_masks = results["gt_masks"] + tests.append(gt_masks.areas >= self.min_gt_mask_area) + + keep = tests[0] + for t in tests[1:]: + keep = keep & t + + if not keep.any(): + if self.keep_empty: + return None + + keys = ( + "gt_bboxes", + "gt_bboxes_labels", + "gt_masks", + "gt_instances_ids", + "gt_ignore_flags", + ) + for key in keys: + if key in results: + results[key] = results[key][keep] + + return results + + def __repr__(self): + return ( + self.__class__.__name__ + f"(min_gt_bbox_wh={self.min_gt_bbox_wh}, " + f"keep_empty={self.keep_empty})" + ) + + +@TRANSFORMS.register_module(force=True) +class SeqCopyPaste(BaseTransform): + """Simple Copy-Paste is a Strong Data Augmentation Method for Instance + Segmentation The simple copy-paste transform steps are as follows: + + 1. The destination image is already resized with aspect ratio kept, + cropped and padded. + 2. Randomly select a source image, which is also already resized + with aspect ratio kept, cropped and padded in a similar way + as the destination image. + 3. Randomly select some objects from the source image. + 4. Paste these source objects to the destination image directly, + due to the source and destination image have the same size. + 5. Update object masks of the destination image, for some origin objects + may be occluded. + 6. Generate bboxes from the updated destination masks and + filter some objects which are totally occluded, and adjust bboxes + which are partly occluded. + 7. Append selected source bboxes, masks, and labels. 
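+    8. Concatenate the ``gt_instances_ids`` of both images in the same
+       order, so pasted objects keep their original instance ids.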
+ + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - gt_masks (BitmapMasks) (optional) + + Modified Keys: + + - img + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + - gt_masks (optional) + + Args: + max_num_pasted (int): The maximum number of pasted objects. + Defaults to 100. + bbox_occluded_thr (int): The threshold of occluded bbox. + Defaults to 10. + mask_occluded_thr (int): The threshold of occluded mask. + Defaults to 300. + selected (bool): Whether select objects or not. If select is False, + all objects of the source image will be pasted to the + destination image. + Defaults to True. + paste_by_box (bool): Whether use boxes as masks when masks are not + available. + Defaults to False. + """ + + def __init__( + self, + max_num_pasted: int = 100, + bbox_occluded_thr: int = 10, + mask_occluded_thr: int = 300, + selected: bool = True, + paste_by_box: bool = False, + ) -> None: + self.max_num_pasted = max_num_pasted + self.bbox_occluded_thr = bbox_occluded_thr + self.mask_occluded_thr = mask_occluded_thr + self.selected = selected + self.paste_by_box = paste_by_box + + @cache_randomness + def get_indexes(self, dataset: BaseDataset) -> int: + """Call function to collect indexes.s. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + Returns: + list: Indexes. + """ + return random.randint(0, len(dataset)) + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to make a copy-paste of image. + + Args: + results (dict): Result dict. + Returns: + dict: Result dict with copy-paste transformed. + """ + + assert "copypaste_mix_results" in results + num_images = len(results["copypaste_mix_results"]) + assert ( + num_images == 1 + ), f"CopyPaste only supports processing 2 images, got {num_images}" + if self.selected: + selected_results = copy.deepcopy( + self._select_object(results["copypaste_mix_results"][0]) + ) + else: + selected_results = copy.deepcopy(results["copypaste_mix_results"][0]) + + return self._copy_paste(results, selected_results) + + @cache_randomness + def _get_selected_inds(self, num_bboxes: int) -> np.ndarray: + max_num_pasted = min(num_bboxes + 1, self.max_num_pasted) + num_pasted = np.random.randint(0, max_num_pasted) + return np.random.choice(num_bboxes, size=num_pasted, replace=False) + + def get_gt_masks(self, results: dict) -> BitmapMasks: + """Get gt_masks originally or generated based on bboxes. + + If gt_masks is not contained in results, + it will be generated based on gt_bboxes. + Args: + results (dict): Result dict. + Returns: + BitmapMasks: gt_masks, originally or generated based on bboxes. + """ + if results.get("gt_masks", None) is not None: + if self.paste_by_box: + warnings.warn( + "gt_masks is already contained in results, " + "so paste_by_box is disabled." 
+ ) + return results["gt_masks"] + else: + if not self.paste_by_box: + raise RuntimeError("results does not contain masks.") + return results["gt_bboxes"].create_masks(results["img"].shape[:2]) + + def _select_object(self, results: dict) -> dict: + """Select some objects from the source results.""" + bboxes = results["gt_bboxes"] + labels = results["gt_bboxes_labels"] + masks = self.get_gt_masks(results) + ignore_flags = results["gt_ignore_flags"] + gt_instances_ids = results.get("gt_instances_ids", None) + + selected_inds = self._get_selected_inds(bboxes.shape[0]) + + selected_bboxes = bboxes[selected_inds] + selected_labels = labels[selected_inds] + selected_masks = masks[selected_inds] + selected_ignore_flags = ignore_flags[selected_inds] + selected_gt_instances_ids = gt_instances_ids[selected_inds] + + results["gt_bboxes"] = selected_bboxes + results["gt_bboxes_labels"] = selected_labels + results["gt_masks"] = selected_masks + results["gt_ignore_flags"] = selected_ignore_flags + results["gt_instances_ids"] = selected_gt_instances_ids + return results + + def _copy_paste(self, dst_results: dict, src_results: dict) -> dict: + """CopyPaste transform function. + + Args: + dst_results (dict): Result dict of the destination image. + src_results (dict): Result dict of the source image. + Returns: + dict: Updated result dict. + """ + dst_img = dst_results["img"] + dst_bboxes = dst_results["gt_bboxes"] + dst_labels = dst_results["gt_bboxes_labels"] + dst_masks = self.get_gt_masks(dst_results) + dst_ignore_flags = dst_results["gt_ignore_flags"] + dst_instances_ids = dst_results.get("gt_instances_ids", None) + + src_img = src_results["img"] + src_bboxes = src_results["gt_bboxes"] + src_labels = src_results["gt_bboxes_labels"] + src_masks = src_results["gt_masks"] + src_ignore_flags = src_results["gt_ignore_flags"] + src_instances_ids = src_results.get("gt_instances_ids", None) + + if len(src_bboxes) == 0: + return dst_results + + # update masks and generate bboxes from updated masks + composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0) + updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask) + updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes)) + assert len(updated_dst_bboxes) == len(updated_dst_masks) + + # filter totally occluded objects + l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs() + bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(dim=-1).numpy() + masks_inds = updated_dst_masks.masks.sum(axis=(1, 2)) > self.mask_occluded_thr + valid_inds = bboxes_inds | masks_inds + + # Paste source objects to destination image directly + img = ( + dst_img * (1 - composed_mask[..., np.newaxis]) + + src_img * composed_mask[..., np.newaxis] + ) + bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes]) + labels = np.concatenate([dst_labels[valid_inds], src_labels]) + masks = np.concatenate([updated_dst_masks.masks[valid_inds], src_masks.masks]) + ignore_flags = np.concatenate([dst_ignore_flags[valid_inds], src_ignore_flags]) + instances_ids = np.concatenate( + [dst_instances_ids[valid_inds], src_instances_ids] + ) + + dst_results["img"] = img + dst_results["gt_bboxes"] = bboxes + dst_results["gt_bboxes_labels"] = labels + dst_results["gt_masks"] = BitmapMasks(masks, masks.shape[1], masks.shape[2]) + dst_results["gt_ignore_flags"] = ignore_flags + dst_results["gt_instances_ids"] = instances_ids + + return dst_results + + def _get_updated_masks( + self, masks: BitmapMasks, composed_mask: np.ndarray + ) -> BitmapMasks: + 
"""Update masks with composed mask.""" + assert ( + masks.masks.shape[-2:] == composed_mask.shape[-2:] + ), "Cannot compare two arrays of different size" + masks.masks = np.where(composed_mask, 0, masks.masks) + return masks + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(max_num_pasted={self.max_num_pasted}, " + repr_str += f"bbox_occluded_thr={self.bbox_occluded_thr}, " + repr_str += f"mask_occluded_thr={self.mask_occluded_thr}, " + repr_str += f"selected={self.selected}), " + repr_str += f"paste_by_box={self.paste_by_box})" + return repr_str + + +@TRANSFORMS.register_module(force=True) +class SeqRandomAffine(BaseTransform): + """Random affine transform data augmentation. + + This operation randomly generates affine transform matrix which including + rotation, translation, shear and scaling transforms. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + Args: + max_rotate_degree (float): Maximum degrees of rotation transform. + Defaults to 10. + max_translate_ratio (float): Maximum ratio of translation. + Defaults to 0.1. + scaling_ratio_range (tuple[float]): Min and max ratio of + scaling transform. Defaults to (0.5, 1.5). + max_shear_degree (float): Maximum degrees of shear + transform. Defaults to 2. + border (tuple[int]): Distance from width and height sides of input + image to adjust output shape. Only used in mosaic dataset. + Defaults to (0, 0). + border_val (tuple[int]): Border padding values of 3 channels. + Defaults to (114, 114, 114). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. 
+ """ + + def __init__( + self, + max_rotate_degree: float = 10.0, + max_translate_ratio: float = 0.1, + scaling_ratio_range: Tuple[float, float] = (0.5, 1.5), + max_shear_degree: float = 2.0, + border: Tuple[int, int] = (0, 0), + border_val: Tuple[int, int, int] = (114, 114, 114), + bbox_clip_border: bool = True, + ) -> None: + assert 0 <= max_translate_ratio <= 1 + assert scaling_ratio_range[0] <= scaling_ratio_range[1] + assert scaling_ratio_range[0] > 0 + self.max_rotate_degree = max_rotate_degree + self.max_translate_ratio = max_translate_ratio + self.scaling_ratio_range = scaling_ratio_range + self.max_shear_degree = max_shear_degree + self.border = border + self.border_val = border_val + self.bbox_clip_border = bbox_clip_border + + @cache_randomness + def _get_random_homography_matrix(self, height, width): + # Rotation + rotation_degree = random.uniform( + -self.max_rotate_degree, self.max_rotate_degree + ) + rotation_matrix = self._get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform( + self.scaling_ratio_range[0], self.scaling_ratio_range[1] + ) + scaling_matrix = self._get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform(-self.max_shear_degree, self.max_shear_degree) + y_degree = random.uniform(-self.max_shear_degree, self.max_shear_degree) + shear_matrix = self._get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = ( + random.uniform(-self.max_translate_ratio, self.max_translate_ratio) * width + ) + trans_y = ( + random.uniform(-self.max_translate_ratio, self.max_translate_ratio) * height + ) + translate_matrix = self._get_translation_matrix(trans_x, trans_y) + + warp_matrix = translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix + return warp_matrix + + @autocast_box_type() + def transform(self, results: dict) -> dict: + img = results["img"] + height = img.shape[0] + self.border[1] * 2 + width = img.shape[1] + self.border[0] * 2 + + warp_matrix = self._get_random_homography_matrix(height, width) + + img = cv2.warpPerspective( + img, warp_matrix, dsize=(width, height), borderValue=self.border_val + ) + results["img"] = img + results["img_shape"] = img.shape[:2] + + bboxes = results["gt_bboxes"] + num_bboxes = len(bboxes) + if num_bboxes: + bboxes.project_(warp_matrix) + if self.bbox_clip_border: + bboxes.clip_([height, width]) + # remove outside bbox + valid_index = bboxes.is_inside([height, width]).numpy() + results["gt_bboxes"] = bboxes[valid_index] + results["gt_bboxes_labels"] = results["gt_bboxes_labels"][valid_index] + results["gt_ignore_flags"] = results["gt_ignore_flags"][valid_index] + results["gt_instances_ids"] = results["gt_instances_ids"][valid_index] + assert len(results["gt_bboxes"]) == len(results["gt_instances_ids"]) + if "gt_masks" in results: + raise NotImplementedError("RandomAffine only supports bbox.") + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(max_rotate_degree={self.max_rotate_degree}, " + repr_str += f"max_translate_ratio={self.max_translate_ratio}, " + repr_str += f"scaling_ratio_range={self.scaling_ratio_range}, " + repr_str += f"max_shear_degree={self.max_shear_degree}, " + repr_str += f"border={self.border}, " + repr_str += f"border_val={self.border_val}, " + repr_str += f"bbox_clip_border={self.bbox_clip_border})" + return repr_str + + @staticmethod + def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray: + radian = math.radians(rotate_degrees) + rotation_matrix = np.array( + [ + [np.cos(radian), -np.sin(radian), 
0.0], + [np.sin(radian), np.cos(radian), 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + return rotation_matrix + + @staticmethod + def _get_scaling_matrix(scale_ratio: float) -> np.ndarray: + scaling_matrix = np.array( + [[scale_ratio, 0.0, 0.0], [0.0, scale_ratio, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float32, + ) + return scaling_matrix + + @staticmethod + def _get_shear_matrix(x_shear_degrees: float, y_shear_degrees: float) -> np.ndarray: + x_radian = math.radians(x_shear_degrees) + y_radian = math.radians(y_shear_degrees) + shear_matrix = np.array( + [[1, np.tan(x_radian), 0.0], [np.tan(y_radian), 1, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float32, + ) + return shear_matrix + + @staticmethod + def _get_translation_matrix(x: float, y: float) -> np.ndarray: + translation_matrix = np.array( + [[1, 0.0, x], [0.0, 1, y], [0.0, 0.0, 1.0]], dtype=np.float32 + ) + return translation_matrix diff --git a/masa/datasets/pipelines/wrappers.py b/masa/datasets/pipelines/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2e7f94849e40842f3562b81ee3a9b90f3ecc04 --- /dev/null +++ b/masa/datasets/pipelines/wrappers.py @@ -0,0 +1,173 @@ +from typing import Callable, Dict, List, Optional, Sequence, Union + +import cv2 +import numpy as np +from mmcv.transforms import TRANSFORMS +from mmcv.transforms.utils import cache_random_params +from mmcv.transforms.wrappers import * + +# Define type of transform or transform config +Transform = Union[Dict, Callable[[Dict], Dict]] + +# Indicator of keys marked by KeyMapper._map_input, which means ignoring the +# marked keys in KeyMapper._apply_transform so they will be invisible to +# wrapped transforms. +# This can be 2 possible case: +# 1. The key is required but missing in results +# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means +# the original value in results should be ignored +IgnoreKey = object() + +# Import nullcontext if python>=3.7, otherwise use a simple alternative +# implementation. +try: + from contextlib import nullcontext # type: ignore +except ImportError: + from contextlib import contextmanager + + @contextmanager # type: ignore + def nullcontext(resource=None): + try: + yield resource + finally: + pass + + +def imdenormalize(img, mean, std, to_bgr=True): + assert img.dtype != np.uint8 + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = cv2.multiply(img, std) # make a copy + cv2.add(img, mean, img) # inplace + if to_bgr: + cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace + return img + + +@TRANSFORMS.register_module() +class MasaTransformBroadcaster(KeyMapper): + """A transform wrapper to apply the wrapped transforms to multiple data + items. For example, apply Resize to multiple images. + + Args: + transforms (list[dict | callable]): Sequence of transform object or + config dict to be wrapped. + mapping (dict): A dict that defines the input key mapping. + Note that to apply the transforms to multiple data items, the + outer keys of the target items should be remapped as a list with + the standard inner key (The key required by the wrapped transform). + See the following example and the document of + ``mmcv.transforms.wrappers.KeyMapper`` for details. + remapping (dict): A dict that defines the output key mapping. + The keys and values have the same meanings and rules as in the + ``mapping``. Default: None. + auto_remap (bool, optional): If True, an inverse of the mapping will + be used as the remapping. 
If auto_remap is not given, it will be + automatically set True if 'remapping' is not given, and vice + versa. Default: None. + allow_nonexist_keys (bool): If False, the outer keys in the mapping + must exist in the input data, or an exception will be raised. + Default: False. + share_random_params (bool): If True, the random transform + (e.g., RandomFlip) will be conducted in a deterministic way and + have the same behavior on all data items. For example, to randomly + flip either both input image and ground-truth image, or none. + Default: False. + + """ + + def __init__( + self, + transforms: List[Union[Dict, Callable[[Dict], Dict]]], + mapping: Optional[Dict] = None, + remapping: Optional[Dict] = None, + auto_remap: Optional[bool] = None, + allow_nonexist_keys: bool = False, + share_random_params: bool = False, + ): + super().__init__( + transforms, mapping, remapping, auto_remap, allow_nonexist_keys + ) + + self.share_random_params = share_random_params + + def scatter_sequence(self, data: Dict) -> List[Dict]: + """Scatter the broadcasting targets to a list of inputs of the wrapped + transforms.""" + + # infer split number from input + seq_len = 0 + key_rep = None + + if self.mapping: + keys = self.mapping.keys() + else: + keys = data.keys() + + for key in keys: + assert isinstance(data[key], Sequence) + if seq_len: + if len(data[key]) != seq_len: + raise ValueError( + "Got inconsistent sequence length: " + f"{seq_len} ({key_rep}) vs. " + f"{len(data[key])} ({key})" + ) + else: + seq_len = len(data[key]) + key_rep = key + + assert seq_len > 0, "Fail to get the number of broadcasting targets" + + scatters = [] + for i in range(seq_len): # type: ignore + scatter = data.copy() + for key in keys: + scatter[key] = data[key][i] + scatters.append(scatter) + return scatters + + def transform(self, results: Dict): + """Broadcast wrapped transforms to multiple targets.""" + + # Apply input remapping + inputs = self._map_input(results, self.mapping) + + # Scatter sequential inputs into a list + input_scatters = self.scatter_sequence(inputs) + + # Control random parameter sharing with a context manager + if self.share_random_params: + # The context manager :func`:cache_random_params` will let + # cacheable method of the transforms cache their outputs. Thus + # the random parameters will only generated once and shared + # by all data items. 
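+            # For example, a wrapped RandomFlip will then flip either every
+            # image in the sequence or none of them, keeping a clip consistent.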
+ ctx = cache_random_params # type: ignore + else: + ctx = nullcontext # type: ignore + + with ctx(self.transforms): + output_scatters = [ + self._apply_transforms(_input) for _input in input_scatters + ] + + outputs = { + key: [_output[key] for _output in output_scatters] + for key in output_scatters[0] + } + + # Apply remapping + outputs = self._map_output(outputs, self.remapping) + + results.update(outputs) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(transforms = {self.transforms}" + repr_str += f", mapping = {self.mapping}" + repr_str += f", remapping = {self.remapping}" + repr_str += f", auto_remap = {self.auto_remap}" + repr_str += f", allow_nonexist_keys = {self.allow_nonexist_keys}" + repr_str += f", share_random_params = {self.share_random_params})" + return repr_str diff --git a/masa/datasets/rsconcat_dataset.py b/masa/datasets/rsconcat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db3e87660f81c0358dd47b1f1089ad012c6dcb94 --- /dev/null +++ b/masa/datasets/rsconcat_dataset.py @@ -0,0 +1,209 @@ +import random +from typing import Iterable, List + +import numpy as np +from mmdet.datasets.base_det_dataset import BaseDetDataset +from mmdet.datasets.base_video_dataset import BaseVideoDataset +from mmdet.registry import DATASETS +from mmengine.dataset import BaseDataset +from torch.utils.data import Dataset +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + + +@DATASETS.register_module() +class RandomSampleConcatDataset(_ConcatDataset): + def __init__( + self, + datasets: Iterable[Dataset], + sampling_probs: List[float], + fixed_length: int, + lazy_init: bool = False, + ): + super(RandomSampleConcatDataset, self).__init__(datasets) + assert len(sampling_probs) == len( + datasets + ), "Number of sampling probabilities must match the number of datasets" + assert sum(sampling_probs) == 1.0, "Sum of sampling probabilities must be 1.0" + + self.datasets: List[BaseDataset] = [] + for i, dataset in enumerate(datasets): + if isinstance(dataset, dict): + self.datasets.append(DATASETS.build(dataset)) + elif isinstance(dataset, BaseDataset): + self.datasets.append(dataset) + else: + raise TypeError( + "elements in datasets sequence should be config or " + f"`BaseDataset` instance, but got {type(dataset)}" + ) + self.sampling_probs = sampling_probs + self.fixed_length = fixed_length + + self.metainfo = self.datasets[0].metainfo + total_datasets_length = sum([len(dataset) for dataset in self.datasets]) + assert ( + self.fixed_length <= total_datasets_length + ), "the length of the concatenated dataset must be less than the sum of the lengths of the individual datasets" + self.flag = np.zeros(self.fixed_length, dtype=np.uint8) + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + for i, dataset in enumerate(self.datasets): + dataset.full_init() + self._ori_len = self.fixed_length + self._fully_initialized = True + + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. 
+ """ + return self.dataset.get_data_info(idx) + + def __len__(self): + return self.fixed_length + + def __getitem__(self, idx): + # Choose a dataset based on the sampling probabilities + chosen_dataset_idx = random.choices( + range(len(self.datasets)), weights=self.sampling_probs, k=1 + )[0] + chosen_dataset = self.datasets[chosen_dataset_idx] + + # Sample a random item from the chosen dataset + sample_idx = random.randrange(0, len(chosen_dataset)) + return chosen_dataset[sample_idx] + + +@DATASETS.register_module() +class RandomSampleJointVideoConcatDataset(_ConcatDataset): + def __init__( + self, + datasets: Iterable[Dataset], + fixed_length: int, + lazy_init: bool = False, + video_sampling_probs: List[float] = [], + img_sampling_probs: List[float] = [], + *args, + **kwargs, + ): + super(RandomSampleJointVideoConcatDataset, self).__init__(datasets) + + self.datasets: List[BaseDataset] = [] + for i, dataset in enumerate(datasets): + if isinstance(dataset, dict): + self.datasets.append(DATASETS.build(dataset)) + elif isinstance(dataset, BaseDataset): + self.datasets.append(dataset) + else: + raise TypeError( + "elements in datasets sequence should be config or " + f"`BaseDataset` instance, but got {type(dataset)}" + ) + + self.video_dataset_idx = [] + self.img_dataset_idx = [] + self.datasets_indices_mapping = {} + for i, dataset in enumerate(self.datasets): + if isinstance(dataset, BaseVideoDataset): + self.video_dataset_idx.append(i) + num_videos = len(dataset) + video_indices = [] + for video_ind in range(num_videos): + video_indices.extend( + [ + (video_ind, frame_ind) + for frame_ind in range(dataset.get_len_per_video(video_ind)) + ] + ) + self.datasets_indices_mapping[i] = video_indices + + elif isinstance(dataset, BaseDetDataset): + self.img_dataset_idx.append(i) + img_indices = [] + num_imgs = len(dataset) + for img_ind in range(num_imgs): + img_indices.extend([img_ind]) + self.datasets_indices_mapping[i] = img_indices + + else: + raise TypeError( + "elements in datasets sequence should be config or " + f"`BaseDataset` instance, but got {type(dataset)}" + ) + + self.fixed_length = fixed_length + self.metainfo = self.datasets[0].metainfo + total_datasets_length = sum( + [len(indices) for key, indices in self.datasets_indices_mapping.items()] + ) + assert ( + self.fixed_length <= total_datasets_length + ), "the length of the concatenated dataset must be less than the sum of the lengths of the individual datasets" + self.flag = np.zeros(self.fixed_length, dtype=np.uint8) + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + self.video_sampling_probs = video_sampling_probs + self.img_sampling_probs = img_sampling_probs + if self.video_sampling_probs: + assert ( + sum(self.video_sampling_probs) == 1.0 + ), "Sum of video sampling probabilities must be 1.0" + if self.img_sampling_probs: + assert ( + sum(self.img_sampling_probs) == 1.0 + ), "Sum of image sampling probabilities must be 1.0" + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + for i, dataset in enumerate(self.datasets): + dataset.full_init() + self._ori_len = self.fixed_length + self._fully_initialized = True + + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. 
+ """ + return self.dataset.get_data_info(idx) + + def __len__(self): + return self.fixed_length + + def __getitem__(self, idx): + # idx ==0 means samples from video dataset, idx == 1 means samples from image dataset + # Choose a dataset based on the sampling probabilities + if idx == 0: + chosen_dataset_idx = random.choices( + self.video_dataset_idx, weights=self.video_sampling_probs, k=1 + )[0] + elif idx == 1: + chosen_dataset_idx = random.choices( + self.img_dataset_idx, weights=self.img_sampling_probs, k=1 + )[0] + + chosen_dataset = self.datasets[chosen_dataset_idx] + # Sample a random item from the chosen dataset + sample_idx = random.choice(self.datasets_indices_mapping[chosen_dataset_idx]) + + return chosen_dataset[sample_idx] diff --git a/masa/datasets/samplers/__init__.py b/masa/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcb5dfc9347cab47813bcc6f484aa06b08ce365 --- /dev/null +++ b/masa/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +from .distributed_video_sampler import DistributedVideoSampler +from .hybrid_video_img_sampler import HybridVideoImgSampler + +__all__ = ["DistributedVideoSampler", HybridVideoImgSampler] diff --git a/masa/datasets/samplers/distributed_video_sampler.py b/masa/datasets/samplers/distributed_video_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..70f7fdc4dc54e7b75c6260a9b2b738af6c8d1e43 --- /dev/null +++ b/masa/datasets/samplers/distributed_video_sampler.py @@ -0,0 +1,28 @@ +import numpy as np +from torch.utils.data import DistributedSampler as _DistributedSampler + + +class DistributedVideoSampler(_DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False): + super().__init__(dataset, num_replicas=num_replicas, rank=rank) + self.shuffle = shuffle + assert not self.shuffle, "Specific for video sequential testing." + self.num_samples = len(dataset) + + first_frame_indices = [] + for i, img_info in enumerate(self.dataset.data_infos): + if img_info["frame_id"] == 0: + first_frame_indices.append(i) + + chunks = np.array_split(first_frame_indices, num_replicas) + split_flags = [c[0] for c in chunks] + split_flags.append(self.num_samples) + + self.indices = [ + list(range(split_flags[i], split_flags[i + 1])) + for i in range(self.num_replicas) + ] + + def __iter__(self): + indices = self.indices[self.rank] + return iter(indices) diff --git a/masa/datasets/samplers/hybrid_video_img_sampler.py b/masa/datasets/samplers/hybrid_video_img_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..557107f08e10c3df818aa35260cedba8e65b27d1 --- /dev/null +++ b/masa/datasets/samplers/hybrid_video_img_sampler.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import random +from typing import Iterator, Optional, Sized + +import numpy as np +from mmdet.datasets.base_det_dataset import BaseDetDataset +from mmdet.datasets.base_video_dataset import BaseVideoDataset +from mmdet.registry import DATA_SAMPLERS +from mmengine.dataset import ClassBalancedDataset, ConcatDataset +from mmengine.dist import get_dist_info, sync_random_seed +from torch.utils.data import Sampler +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from ..dataset_wrappers import SeqMultiImageMixDataset + + +@DATA_SAMPLERS.register_module() +class HybridVideoImgSampler(Sampler): + """Sampler that providing image-level sampling outputs for video datasets + in tracking tasks. 
It could be both used in both distributed and + non-distributed environment. + If using the default sampler in pytorch, the subsequent data receiver will + get one video, which is not desired in some cases: + (Take a non-distributed environment as an example) + 1. In test mode, we want only one image is fed into the data pipeline. This + is in consideration of memory usage since feeding the whole video commonly + requires a large amount of memory (>=20G on MOTChallenge17 dataset), which + is not available in some machines. + 2. In training mode, we may want to make sure all the images in one video + are randomly sampled once in one epoch and this can not be guaranteed in + the default sampler in pytorch. + + Args: + dataset (Sized): Dataset used for sampling. + seed (int, optional): random seed used to shuffle the sampler. This + number should be identical across all processes in the distributed + group. Defaults to None. + """ + + def __init__(self, dataset: Sized, seed: Optional[int] = None,) -> None: + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + self.epoch = 0 + if seed is None: + self.seed = sync_random_seed() + else: + self.seed = seed + + self.dataset = dataset + self.indices = [] + # Hard code here to handle different dataset wrapper + if isinstance(self.dataset, ConcatDataset): + cat_datasets = self.dataset.datasets + assert isinstance( + cat_datasets[0], BaseVideoDataset + ), f"expected BaseVideoDataset, but got {type(cat_datasets[0])}" + self.test_mode = cat_datasets[0].test_mode + assert not self.test_mode, "'ConcatDataset' should not exist in " + "test mode" + for dataset in cat_datasets: + num_videos = len(dataset) + for video_ind in range(num_videos): + self.indices.extend( + [ + (video_ind, frame_ind) + for frame_ind in range(dataset.get_len_per_video(video_ind)) + ] + ) + elif isinstance(self.dataset, ClassBalancedDataset): + ori_dataset = self.dataset.dataset + assert isinstance( + ori_dataset, BaseVideoDataset + ), f"expected BaseVideoDataset, but got {type(ori_dataset)}" + self.test_mode = ori_dataset.test_mode + assert not self.test_mode, "'ClassBalancedDataset' should not " + "exist in test mode" + video_indices = self.dataset.repeat_indices + for index in video_indices: + self.indices.extend( + [ + (index, frame_ind) + for frame_ind in range(ori_dataset.get_len_per_video(index)) + ] + ) + elif isinstance(self.dataset, BaseVideoDataset): + self.test_mode = self.dataset.test_mode + num_videos = len(self.dataset) + + if self.test_mode: + # in test mode, the images belong to the same video must be put + # on the same device. + if num_videos < self.world_size: + raise ValueError( + f"only {num_videos} videos loaded," + f"but {self.world_size} gpus were given." 
+ ) + chunks = np.array_split(list(range(num_videos)), self.world_size) + for videos_inds in chunks: + indices_chunk = [] + for video_ind in videos_inds: + indices_chunk.extend( + [ + (video_ind, frame_ind) + for frame_ind in range( + self.dataset.get_len_per_video(video_ind) + ) + ] + ) + self.indices.append(indices_chunk) + else: + for video_ind in range(num_videos): + self.indices.extend( + [ + (video_ind, frame_ind) + for frame_ind in range( + self.dataset.get_len_per_video(video_ind) + ) + ] + ) + else: + assert isinstance(self.dataset, SeqMultiImageMixDataset), ( + "HybridVideoImgSampler is only supported in BaseVideoDataset or " + "dataset wrapper: ClassBalancedDataset and ConcatDataset,SeqMultiImageMixDataset, but " + f"got {type(self.dataset)} " + ) + self.test_mode = self.dataset.test_mode + # num_videos = len(self.dataset) + if self.test_mode: + print("Not support test mode") + raise NotImplementedError + else: + assert isinstance( + self.dataset.dataset, _ConcatDataset + ), "HybridVideoImgSampler is only supported in _ConcatDataset" + cat_datasets = self.dataset.dataset.datasets + for dataset in cat_datasets: + self.test_mode = dataset.test_mode + assert not self.test_mode, "'ConcatDataset' should not exist in " + "test mode" + if isinstance(dataset, BaseVideoDataset): + num_videos = len(dataset) + video_indices = [] + for video_ind in range(num_videos): + video_indices.extend( + [ + (video_ind, frame_ind) + for frame_ind in range( + dataset.get_len_per_video(video_ind) + ) + ] + ) + elif isinstance(dataset, BaseDetDataset): + img_indices = [] + num_imgs = len(dataset) + for img_ind in range(num_imgs): + img_indices.extend([img_ind]) + ###### special process to make debug task easier ##### + def alternate_merge(list1, list2): + # Create a new list to hold the merged elements + merged_list = [] + + # Get the length of the shorter list + min_length = min(len(list1), len(list2)) + + # Append elements alternately from both lists + for i in range(min_length): + merged_list.append(list1[i]) + merged_list.append(list2[i]) + + # Append the remaining elements from the longer list + if len(list1) > len(list2): + merged_list.extend(list1[min_length:]) + else: + merged_list.extend(list2[min_length:]) + + return merged_list + + self.indices = alternate_merge(img_indices, video_indices) + + if self.test_mode: + self.num_samples = len(self.indices[self.rank]) + self.total_size = sum([len(index_list) for index_list in self.indices]) + else: + self.num_samples = int(math.ceil(len(self.indices) * 1.0 / self.world_size)) + self.total_size = self.num_samples * self.world_size + + def __iter__(self) -> Iterator: + if self.test_mode: + # in test mode, the order of frames can not be shuffled. 
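+            # Each rank simply iterates over its pre-assigned chunk of whole
+            # videos, so all frames of a video stay on the same device.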
+ indices = self.indices[self.rank] + else: + # deterministically shuffle based on epoch + rng = random.Random(self.epoch + self.seed) + indices = rng.sample(self.indices, len(self.indices)) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/masa/datasets/tao_masa_dataset.py b/masa/datasets/tao_masa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1600714cd7b59a6e9c2f4d692dbc167a132e82 --- /dev/null +++ b/masa/datasets/tao_masa_dataset.py @@ -0,0 +1,101 @@ +import os.path as osp +from collections import defaultdict +from typing import Any, List, Tuple + +import numpy as np +from mmdet.datasets import BaseVideoDataset, LVISV1Dataset, LVISV05Dataset +from mmdet.registry import DATASETS + + +@DATASETS.register_module() +class Taov05Dataset(BaseVideoDataset): + """Dataset for TAO benchmark. + + """ + + METAINFO = LVISV05Dataset.METAINFO + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.flag = np.zeros(len(self), dtype=np.uint8) + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def prepare_data(self, idx) -> Any: + """Get date processed by ``self.pipeline``. Note that ``idx`` is a + video index in default since the base element of video dataset is a + video. However, in some cases, we need to specific both the video index + and frame index. For example, in traing mode, we may want to sample the + specific frames and all the frames must be sampled once in a epoch; in + test mode, we may want to output data of a single image rather than the + whole video for saving memory. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + if isinstance(idx, tuple): + assert len(idx) == 2, "The length of idx must be 2: " + "(video_index, frame_index)" + video_idx, frame_idx = idx[0], idx[1] + else: + video_idx, frame_idx = idx, None + + data_info = self.get_data_info(video_idx) + if self.test_mode: + # Support two test_mode: frame-level and video-level + final_data_info = defaultdict(list) + if frame_idx is None: + frames_idx_list = list(range(data_info["video_length"])) + else: + frames_idx_list = [frame_idx] + for index in frames_idx_list: + frame_ann = data_info["images"][index] + frame_ann["video_id"] = data_info["video_id"] + # Collate data_list (list of dict to dict of list) + for key, value in frame_ann.items(): + final_data_info[key].append(value) + # copy the info in video-level into img-level + # TODO: the value of this key is the same as that of + # `video_length` in test mode + final_data_info["ori_video_length"].append(data_info["video_length"]) + + final_data_info["video_length"] = [len(frames_idx_list)] * len( + frames_idx_list + ) + return self.pipeline(final_data_info) + else: + # Specify `key_frame_id` for the frame sampling in the pipeline + if frame_idx is not None: + data_info["key_frame_id"] = frame_idx + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. 
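+            # If the pipeline raises (e.g. every instance was filtered out by
+            # augmentation), re-sample another video from the same flag group
+            # and retry, up to ``max_refetch`` times.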
+ try: + + data = self.pipeline(data_info) + except Exception as e: + print("Error occurred while running pipeline", f" with error: {e}") + # print('Empty instances due to augmentation, re-sampling...') + video_idx = self._rand_another(video_idx) + data_info = self.get_data_info(video_idx) + continue + + if data is not None: + break + return data + + +@DATASETS.register_module() +class Taov1Dataset(Taov05Dataset): + """Dataset for TAO benchmark. + + """ + + METAINFO = LVISV1Dataset.METAINFO diff --git a/masa/datasets/utils.py b/masa/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ffebeb11e883afae4b103fc28638c52718ee38ac --- /dev/null +++ b/masa/datasets/utils.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +from mmengine.dataset import COLLATE_FUNCTIONS + + +@COLLATE_FUNCTIONS.register_module() +def yolow_collate(data_batch: Sequence, use_ms_training: bool = False) -> dict: + """Rewrite collate_fn to get faster training speed. + + Args: + data_batch (Sequence): Batch of data. + use_ms_training (bool): Whether to use multi-scale training. + """ + batch_imgs = [] + batch_bboxes_labels = [] + batch_masks = [] + for i in range(len(data_batch)): + datasamples = data_batch[i]["data_samples"] + inputs = data_batch[i]["inputs"] + batch_imgs.append(inputs) + + gt_bboxes = datasamples.gt_instances.bboxes.tensor + gt_labels = datasamples.gt_instances.labels + if "masks" in datasamples.gt_instances: + masks = datasamples.gt_instances.masks.to( + dtype=torch.bool, device=gt_bboxes.device + ) + batch_masks.append(masks) + batch_idx = gt_labels.new_full((len(gt_labels), 1), i) + bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes), dim=1) + batch_bboxes_labels.append(bboxes_labels) + + collated_results = { + "data_samples": {"bboxes_labels": torch.cat(batch_bboxes_labels, 0)} + } + if len(batch_masks) > 0: + collated_results["data_samples"]["masks"] = torch.cat(batch_masks, 0) + + if use_ms_training: + collated_results["inputs"] = batch_imgs + else: + collated_results["inputs"] = torch.stack(batch_imgs, 0) + + if hasattr(data_batch[0]["data_samples"], "texts"): + batch_texts = [meta["data_samples"].texts for meta in data_batch] + collated_results["data_samples"]["texts"] = batch_texts + + if hasattr(data_batch[0]["data_samples"], "is_detection"): + # detection flag + batch_detection = [meta["data_samples"].is_detection for meta in data_batch] + collated_results["data_samples"]["is_detection"] = torch.tensor(batch_detection) + + return collated_results diff --git a/masa/models/__init__.py b/masa/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d68306ba1a2a6fcaf2de10117d21f4230f9bb236 --- /dev/null +++ b/masa/models/__init__.py @@ -0,0 +1,7 @@ +from .detectors import * # noqa +from .losses import * # noqa +from .mot import * # noqa +from .necks import * # noqa +from .roi_heads import MasaTrackHead +from .sam import * +from .tracker import * diff --git a/masa/models/__pycache__/__init__.cpython-311.pyc b/masa/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa2201b76862adecbd550711e28ca916a85352d Binary files /dev/null and b/masa/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/detectors/__init__.py b/masa/models/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e99d2e39e59e5eb549844ecd5ae380e2c804907 --- /dev/null +++ 
b/masa/models/detectors/__init__.py @@ -0,0 +1,6 @@ +from .detic_masa import DeticMasa +from .gdino_masa import GroundingDINOMasa +from .grounding_dino import GroundingDINO +from .sam_masa import SamMasa + +__all__ = ["GroundingDINO", "DeticMasa", "GroundingDINOMasa", "SamMasa"] diff --git a/masa/models/detectors/__pycache__/__init__.cpython-311.pyc b/masa/models/detectors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b2889dd333456bec403657ffd9816bfc21a750 Binary files /dev/null and b/masa/models/detectors/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/detectors/__pycache__/detic_masa.cpython-311.pyc b/masa/models/detectors/__pycache__/detic_masa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41be89eb5a6b3f788deb79c22256a5f037896da6 Binary files /dev/null and b/masa/models/detectors/__pycache__/detic_masa.cpython-311.pyc differ diff --git a/masa/models/detectors/__pycache__/gdino_masa.cpython-311.pyc b/masa/models/detectors/__pycache__/gdino_masa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1f3390239d5888070dfe69c33b515e0cdf7e5f1 Binary files /dev/null and b/masa/models/detectors/__pycache__/gdino_masa.cpython-311.pyc differ diff --git a/masa/models/detectors/__pycache__/grounding_dino.cpython-311.pyc b/masa/models/detectors/__pycache__/grounding_dino.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e35ec6a3209e88f4a4b7091263097a7642e27e01 Binary files /dev/null and b/masa/models/detectors/__pycache__/grounding_dino.cpython-311.pyc differ diff --git a/masa/models/detectors/__pycache__/sam_masa.cpython-311.pyc b/masa/models/detectors/__pycache__/sam_masa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4daf9e86f341d087d602f09143929bb01ba8a9a2 Binary files /dev/null and b/masa/models/detectors/__pycache__/sam_masa.cpython-311.pyc differ diff --git a/masa/models/detectors/detic_masa.py b/masa/models/detectors/detic_masa.py new file mode 100644 index 0000000000000000000000000000000000000000..0577a4c9059f2e62700c87902ae003f5a55f5bfa --- /dev/null +++ b/masa/models/detectors/detic_masa.py @@ -0,0 +1,218 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" +from typing import List, Union + +import numpy as np +import pycocotools.mask as mask_util +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmengine.logging import print_log +from torch import Tensor + +from projects.Detic_new.detic import Detic + + +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list): bitmap mask results. + + Returns: + list | tuple: RLE encoded mask. 
+ """ + encoded_mask_results = [] + for mask in mask_results: + encoded_mask_results.append( + mask_util.encode( + np.array(mask[:, :, np.newaxis], order="F", dtype="uint8") + )[0] + ) # encoded with RLE + return encoded_mask_results + + +class CLIPTextEncoder(nn.Module): + def __init__(self, model_name="ViT-B/32"): + super().__init__() + import clip + from clip.simple_tokenizer import SimpleTokenizer + + self.tokenizer = SimpleTokenizer() + pretrained_model, _ = clip.load(model_name, device="cpu") + self.clip = pretrained_model + + @property + def device(self): + return self.clip.device + + @property + def dtype(self): + return self.clip.dtype + + def tokenize( + self, texts: Union[str, List[str]], context_length: int = 77 + ) -> torch.LongTensor: + if isinstance(texts, str): + texts = [texts] + + sot_token = self.tokenizer.encoder["<|startoftext|>"] + eot_token = self.tokenizer.encoder["<|endoftext|>"] + all_tokens = [ + [sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts + ] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + st = torch.randint(len(tokens) - context_length + 1, (1,))[0].item() + tokens = tokens[st : st + context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + return result + + def forward(self, text): + text = self.tokenize(text) + text_features = self.clip.encode_text(text) + return text_features + + +def get_class_weight(original_caption, prompt_prefix="a "): + if isinstance(original_caption, str): + if original_caption == "coco": + from mmdet.datasets import CocoDataset + + class_names = CocoDataset.METAINFO["classes"] + elif original_caption == "cityscapes": + from mmdet.datasets import CityscapesDataset + + class_names = CityscapesDataset.METAINFO["classes"] + elif original_caption == "voc": + from mmdet.datasets import VOCDataset + + class_names = VOCDataset.METAINFO["classes"] + elif original_caption == "openimages": + from mmdet.datasets import OpenImagesDataset + + class_names = OpenImagesDataset.METAINFO["classes"] + elif original_caption == "lvis": + from mmdet.datasets import LVISV1Dataset + + class_names = LVISV1Dataset.METAINFO["classes"] + else: + if not original_caption.endswith("."): + original_caption = original_caption + " . " + original_caption = original_caption.split(" . 
") + class_names = list(filter(lambda x: len(x) > 0, original_caption)) + + # for test.py + else: + class_names = list(original_caption) + + text_encoder = CLIPTextEncoder() + text_encoder.eval() + texts = [prompt_prefix + x for x in class_names] + print_log(f"Computing text embeddings for {len(class_names)} classes.") + embeddings = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return class_names, embeddings + + +def reset_cls_layer_weight(roi_head, weight): + if type(weight) == str: + print_log(f"Resetting cls_layer_weight from file: {weight}") + zs_weight = ( + torch.tensor(np.load(weight), dtype=torch.float32) + .permute(1, 0) + .contiguous() + ) # D x C + else: + zs_weight = weight + zs_weight = torch.cat( + [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))], dim=1 + ) # D x (C + 1) + zs_weight = F.normalize(zs_weight, p=2, dim=0) + zs_weight = zs_weight.to("cuda") + num_classes = zs_weight.shape[-1] + + for bbox_head in roi_head.bbox_head: + bbox_head.num_classes = num_classes + del bbox_head.fc_cls.zs_weight + bbox_head.fc_cls.zs_weight = zs_weight + + +@MODELS.register_module() +class DeticMasa(Detic): + def predict( + self, + batch_inputs: Tensor, + detection_features: Tensor, + batch_data_samples: SampleList, + rescale: bool = True, + ) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Return the detection results of the + input images. The returns value is DetDataSample, + which usually contain 'pred_instances'. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # For single image inference + if "custom_entities" in batch_data_samples[0]: + text_prompts = batch_data_samples[0].text + if text_prompts != self._text_prompts: + self._text_prompts = text_prompts + class_names, zs_weight = get_class_weight(text_prompts) + self._entities = class_names + reset_cls_layer_weight(self.roi_head, zs_weight) + + assert self.with_bbox, "Bbox head must be implemented." 
+ + # x = self.extract_feat(batch_inputs) + x = detection_features + + # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get("proposals", None) is None: + rpn_results_list = self.rpn_head.predict( + x, batch_data_samples, rescale=False + ) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + results_list = self.roi_head.predict( + x, rpn_results_list, batch_data_samples, rescale=rescale + ) + + for data_sample, pred_instances in zip(batch_data_samples, results_list): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + label_names.append(self._entities[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + + return batch_data_samples diff --git a/masa/models/detectors/gdino_masa.py b/masa/models/detectors/gdino_masa.py new file mode 100644 index 0000000000000000000000000000000000000000..2559a1f03bf144760fe6f1e214ddab73bdba6239 --- /dev/null +++ b/masa/models/detectors/gdino_masa.py @@ -0,0 +1,261 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +import copy +import logging +import re +import warnings + +from mmdet.registry import MODELS +from mmengine.logging import MMLogger, print_log +from mmengine.model.weight_init import (PretrainedInit, initialize, + update_init_info) + +from .grounding_dino import GroundingDINO + + +def clean_label_name(name: str) -> str: + name = re.sub(r"\(.*\)", "", name) + name = re.sub(r"_", " ", name) + name = re.sub(r" ", " ", name) + return name + + +def chunks(lst: list, n: int) -> list: + """Yield successive n-sized chunks from lst.""" + all_ = [] + for i in range(0, len(lst), n): + data_index = lst[i : i + n] + all_.append(data_index) + counter = 0 + for i in all_: + counter += len(i) + assert counter == len(lst) + + return all_ + + +@MODELS.register_module() +class GroundingDINOMasa(GroundingDINO): + """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- + Training for Open-Set Object Detection. + + `_ + + Code is modified from the `official github repo + `_. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.track_text_prompt = None + self.track_text_dict = None + self.token_positive_maps = None + self.track_entities = None + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + if self.init_cfg: + print_log( + f"initialize {self.__class__.__name__} with init_cfg {self.init_cfg}", + logger="current", + level=logging.DEBUG, + ) + + init_cfgs = self.init_cfg + if isinstance(self.init_cfg, dict): + init_cfgs = [self.init_cfg] + + # PretrainedInit has higher priority than any other init_cfg. + # Therefore we initialize `pretrained_cfg` last to overwrite + # the previous initialized weights. 
+ # See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501 + other_cfgs = [] + pretrained_cfg = [] + for init_cfg in init_cfgs: + assert isinstance(init_cfg, dict) + if ( + init_cfg["type"] == "Pretrained" + or init_cfg["type"] is PretrainedInit + ): + pretrained_cfg.append(init_cfg) + else: + other_cfgs.append(init_cfg) + + initialize(self, other_cfgs) + + else: + super().init_weights() + + initialize(self, pretrained_cfg) + + def predict( + self, batch_inputs, detection_features, batch_data_samples, rescale: bool = True + ): + + text_prompts = [] + enhanced_text_prompts = [] + tokens_positives = [] + for data_samples in batch_data_samples: + text_prompts.append(data_samples.text) + if "caption_prompt" in data_samples: + enhanced_text_prompts.append(data_samples.caption_prompt) + else: + enhanced_text_prompts.append(None) + tokens_positives.append(data_samples.get("tokens_positive", None)) + + if "custom_entities" in batch_data_samples[0]: + # Assuming that the `custom_entities` flag + # inside a batch is always the same. For single image inference + custom_entities = batch_data_samples[0].custom_entities + else: + custom_entities = False + + if self.track_text_dict is not None and self.track_text_prompt == text_prompts: + # text feature map layer + + is_rec_tasks = [] + for i, data_samples in enumerate(batch_data_samples): + if self.token_positive_maps[i] is not None: + is_rec_tasks.append(False) + else: + is_rec_tasks.append(True) + data_samples.token_positive_map = self.token_positive_maps[i] + + visual_feats = detection_features + + head_inputs_dict = self.forward_transformer( + visual_feats, self.track_text_dict, batch_data_samples + ) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples, + ) + + entities = self.track_entities + + else: + self.track_text_prompt = text_prompts + + if len(text_prompts) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. 
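+ # This branch only runs when the prompt differs from the cached
+ # `self.track_text_prompt` (e.g. on the first frame of a video); the token
+ # positive maps and entity names computed here are stored in
+ # `self.token_positive_maps` / `self.track_entities` below so that later
+ # frames with the same prompt skip the tokenizer and language model.
+ # Each element of `_positive_maps_and_prompts` is roughly a tuple of
+ # (token_positive_map, caption_string, positive_map, entities), where
+ # `entities` holds the class-name strings that predicted labels are later
+ # mapped back to.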
+ _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts( + text_prompts[0], + custom_entities, + enhanced_text_prompts[0], + tokens_positives[0], + ) + ] * len(batch_inputs) + else: + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts( + text_prompt, + custom_entities, + enhanced_text_prompt, + tokens_positive, + ) + for text_prompt, enhanced_text_prompt, tokens_positive in zip( + text_prompts, enhanced_text_prompts, tokens_positives + ) + ] + token_positive_maps, text_prompts, _, entities = zip( + *_positive_maps_and_prompts + ) + + self.token_positive_maps = token_positive_maps + self.track_entities = entities + + # image feature extraction + visual_feats = detection_features + + if isinstance(text_prompts[0], list): + # chunked text prompts, only bs=1 is supported + assert len(batch_inputs) == 1 + count = 0 + results_list = [] + + entities = [[item for lst in entities[0] for item in lst]] + + for b in range(len(text_prompts[0])): + text_prompts_once = [text_prompts[0][b]] + token_positive_maps_once = token_positive_maps[0][b] + text_dict = self.language_model(text_prompts_once) + # text feature map layer + if self.text_feat_map is not None: + text_dict["embedded"] = self.text_feat_map( + text_dict["embedded"] + ) + + batch_data_samples[0].token_positive_map = token_positive_maps_once + + head_inputs_dict = self.forward_transformer( + copy.deepcopy(visual_feats), text_dict, batch_data_samples + ) + pred_instances = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples, + )[0] + + if len(pred_instances) > 0: + pred_instances.labels += count + count += len(token_positive_maps_once) + results_list.append(pred_instances) + results_list = [results_list[0].cat(results_list)] + is_rec_tasks = [False] * len(results_list) + else: + # extract text feats + text_dict = self.language_model(list(text_prompts)) + # text feature map layer + if self.text_feat_map is not None: + text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) + + is_rec_tasks = [] + for i, data_samples in enumerate(batch_data_samples): + if token_positive_maps[i] is not None: + is_rec_tasks.append(False) + else: + is_rec_tasks.append(True) + data_samples.token_positive_map = token_positive_maps[i] + + if self.track_text_dict is None: + self.track_text_dict = text_dict + + head_inputs_dict = self.forward_transformer( + visual_feats, text_dict, batch_data_samples + ) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples, + ) + + for data_sample, pred_instances, entity, is_rec_task in zip( + batch_data_samples, results_list, entities, is_rec_tasks + ): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + if is_rec_task: + label_names.append(entity) + continue + if labels >= len(entity): + warnings.warn( + "The unexpected output indicates an issue with " + "named entity recognition. You can try " + "setting custom_entities=True and running " + "again to see if it helps." 
+ ) + label_names.append("unobject") + else: + label_names.append(entity[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples diff --git a/masa/models/detectors/grounding_dino.py b/masa/models/detectors/grounding_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..03b35f0cb1485745158a8919a1db5578ac61f1a2 --- /dev/null +++ b/masa/models/detectors/grounding_dino.py @@ -0,0 +1,801 @@ +import copy +import re +import warnings +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmdet.models.detectors.dino import DINO +from mmdet.models.detectors.glip import (create_positive_map, + create_positive_map_label_to_token) +from mmdet.models.layers import SinePositionalEncoding +from mmdet.models.layers.transformer.grounding_dino_layers import ( + GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder) +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType +from mmengine.runner.amp import autocast +from torch import Tensor + +try: + import os + + import nltk + + download_dir = os.path.expanduser("~/nltk_data") + nltk.download("punkt", download_dir=download_dir, quiet=True) + nltk.download("averaged_perceptron_tagger", download_dir=download_dir, quiet=True) +except ImportError: + raise RuntimeError( + "nltk is not installed, please install it by: " "pip install nltk." + ) + + +def find_noun_phrases(caption: str) -> list: + """Find noun phrases in a caption using nltk. + Args: + caption (str): The caption to analyze. + + Returns: + list: List of noun phrases found in the caption. + + Examples: + >>> caption = 'There is two cat and a remote in the picture' + >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture'] + """ + # try: + # import nltk + # import os + # # nltk.download('punkt', download_dir='~/nltk_data') + # # nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data') + # download_dir = os.path.expanduser('~/nltk_data') + # nltk.download('punkt', download_dir=download_dir) + # nltk.download('averaged_perceptron_tagger', download_dir=download_dir) + # except ImportError: + # raise RuntimeError('nltk is not installed, please install it by: ' + # 'pip install nltk.') + + caption = caption.lower() + tokens = nltk.word_tokenize(caption) + pos_tags = nltk.pos_tag(tokens) + + grammar = "NP: {
?*+}" + cp = nltk.RegexpParser(grammar) + result = cp.parse(pos_tags) + + noun_phrases = [] + for subtree in result.subtrees(): + if subtree.label() == "NP": + noun_phrases.append(" ".join(t[0] for t in subtree.leaves())) + + return noun_phrases + + +def remove_punctuation(text: str) -> str: + """Remove punctuation from a text. + Args: + text (str): The input text. + + Returns: + str: The text with punctuation removed. + """ + punctuation = [ + "|", + ":", + ";", + "@", + "(", + ")", + "[", + "]", + "{", + "}", + "^", + "'", + '"', + "’", + "`", + "?", + "$", + "%", + "#", + "!", + "&", + "*", + "+", + ",", + ".", + ] + for p in punctuation: + text = text.replace(p, "") + return text.strip() + + +def run_ner(caption: str) -> Tuple[list, list]: + """Run NER on a caption and return the tokens and noun phrases. + Args: + caption (str): The input caption. + + Returns: + Tuple[List, List]: A tuple containing the tokens and noun phrases. + - tokens_positive (List): A list of token positions. + - noun_phrases (List): A list of noun phrases. + """ + noun_phrases = find_noun_phrases(caption) + noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] + noun_phrases = [phrase for phrase in noun_phrases if phrase != ""] + # print('noun_phrases:', noun_phrases) + relevant_phrases = noun_phrases + labels = noun_phrases + + tokens_positive = [] + for entity, label in zip(relevant_phrases, labels): + try: + # search all occurrences and mark them as different entities + # TODO: Not Robust + for m in re.finditer(entity, caption.lower()): + tokens_positive.append([[m.start(), m.end()]]) + except Exception: + print("noun entities:", noun_phrases) + print("entity:", entity) + print("caption:", caption.lower()) + return tokens_positive, noun_phrases + + +def clean_label_name(name: str) -> str: + name = re.sub(r"\(.*\)", "", name) + name = re.sub(r"_", " ", name) + name = re.sub(r" ", " ", name) + return name + + +def chunks(lst: list, n: int) -> list: + """Yield successive n-sized chunks from lst.""" + all_ = [] + for i in range(0, len(lst), n): + data_index = lst[i : i + n] + all_.append(data_index) + counter = 0 + for i in all_: + counter += len(i) + assert counter == len(lst) + + return all_ + + +@MODELS.register_module(force=True) +class GroundingDINO(DINO): + """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- + Training for Open-Set Object Detection. + + `_ + + Code is modified from the `official github repo + `_. + """ + + def __init__(self, language_model, *args, use_autocast=False, **kwargs) -> None: + + self.language_model_cfg = language_model + self._special_tokens = ". " + self.use_autocast = use_autocast + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding(**self.positional_encoding) + self.encoder = GroundingDinoTransformerEncoder(**self.encoder) + self.decoder = GroundingDinoTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, ( + f"embed_dims should be exactly 2 times of num_feats. " + f"Found {self.embed_dims} and {num_feats}." 
+ ) + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims) + ) + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + + # text modules + self.language_model = MODELS.build(self.language_model_cfg) + self.text_feat_map = nn.Linear( + self.language_model.language_backbone.body.language_dim, + self.embed_dims, + bias=True, + ) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + nn.init.constant_(self.text_feat_map.bias.data, 0) + nn.init.xavier_uniform_(self.text_feat_map.weight.data) + + def to_enhance_text_prompts(self, original_caption, enhanced_text_prompts): + caption_string = "" + tokens_positive = [] + for idx, word in enumerate(original_caption): + if word in enhanced_text_prompts: + enhanced_text_dict = enhanced_text_prompts[word] + if "prefix" in enhanced_text_dict: + caption_string += enhanced_text_dict["prefix"] + start_i = len(caption_string) + if "name" in enhanced_text_dict: + caption_string += enhanced_text_dict["name"] + else: + caption_string += word + end_i = len(caption_string) + tokens_positive.append([[start_i, end_i]]) + + if "suffix" in enhanced_text_dict: + caption_string += enhanced_text_dict["suffix"] + else: + tokens_positive.append( + [[len(caption_string), len(caption_string) + len(word)]] + ) + caption_string += word + caption_string += self._special_tokens + return caption_string, tokens_positive + + def to_plain_text_prompts(self, original_caption): + caption_string = "" + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), len(caption_string) + len(word)]] + ) + caption_string += word + caption_string += self._special_tokens + return caption_string, tokens_positive + + def get_tokens_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False, + enhanced_text_prompts: Optional[ConfigType] = None, + ) -> Tuple[dict, str, list]: + """Get the tokens positive and prompts for the caption.""" + if isinstance(original_caption, (list, tuple)) or custom_entities: + if custom_entities and isinstance(original_caption, str): + original_caption = original_caption.strip(self._special_tokens) + original_caption = original_caption.split(self._special_tokens) + original_caption = list(filter(lambda x: len(x) > 0, original_caption)) + + original_caption = [clean_label_name(i) for i in original_caption] + + if custom_entities and enhanced_text_prompts is not None: + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption, enhanced_text_prompts + ) + else: + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption + ) + + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. + tokenized = self.language_model.tokenizer( + [caption_string], + padding="max_length" if self.language_model.pad_to_max else "longest", + return_tensors="pt", + ) + entities = original_caption + else: + if not original_caption.endswith("."): + original_caption = original_caption + self._special_tokens + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. 
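+ # This branch handles a free-form caption, so the positive spans are
+ # recovered by the noun-phrase NER below (`run_ner`) instead of being
+ # known up front. For comparison, the list/custom-entity branch above
+ # builds the caption itself: e.g. ["person", "car"] becomes roughly
+ # "person. car. " with character spans [[[0, 6]], [[8, 11]]] marking each
+ # class name via `to_plain_text_prompts`.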
+ tokenized = self.language_model.tokenizer( + [original_caption], + padding="max_length" if self.language_model.pad_to_max else "longest", + return_tensors="pt", + ) + tokens_positive, noun_phrases = run_ner(original_caption) + entities = noun_phrases + caption_string = original_caption + + return tokenized, caption_string, tokens_positive, entities + + def get_positive_map(self, tokenized, tokens_positive): + positive_map = create_positive_map( + tokenized, + tokens_positive, + max_num_entities=self.bbox_head.cls_branches[ + self.decoder.num_layers + ].max_text_len, + ) + positive_map_label_to_token = create_positive_map_label_to_token( + positive_map, plus=1 + ) + return positive_map_label_to_token, positive_map + + def get_tokens_positive_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False, + enhanced_text_prompt: Optional[ConfigType] = None, + tokens_positive: Optional[list] = None, + ) -> Tuple[dict, str, Tensor, list]: + """Get the tokens positive and prompts for the caption. + + Args: + original_caption (str): The original caption, e.g. 'bench . car .' + custom_entities (bool, optional): Whether to use custom entities. + If ``True``, the ``original_caption`` should be a list of + strings, each of which is a word. Defaults to False. + + Returns: + Tuple[dict, str, dict, str]: The dict is a mapping from each entity + id, which is numbered from 1, to its positive token id. + The str represents the prompts. + """ + if tokens_positive is not None: + if tokens_positive == -1: + if not original_caption.endswith("."): + original_caption = original_caption + self._special_tokens + return None, original_caption, None, original_caption + else: + if not original_caption.endswith("."): + original_caption = original_caption + self._special_tokens + tokenized = self.language_model.tokenizer( + [original_caption], + padding="max_length" + if self.language_model.pad_to_max + else "longest", + return_tensors="pt", + ) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive + ) + + entities = [] + for token_positive in tokens_positive: + instance_entities = [] + for t in token_positive: + instance_entities.append(original_caption[t[0] : t[1]]) + entities.append(" / ".join(instance_entities)) + return ( + positive_map_label_to_token, + original_caption, + positive_map, + entities, + ) + + chunked_size = self.test_cfg.get("chunked_size", -1) + if not self.training and chunked_size > 0: + assert ( + isinstance(original_caption, (list, tuple)) or custom_entities is True + ) + all_output = self.get_tokens_positive_and_prompts_chunked( + original_caption, enhanced_text_prompt + ) + ( + positive_map_label_to_token, + caption_string, + positive_map, + entities, + ) = all_output + else: + ( + tokenized, + caption_string, + tokens_positive, + entities, + ) = self.get_tokens_and_prompts( + original_caption, custom_entities, enhanced_text_prompt + ) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive + ) + return positive_map_label_to_token, caption_string, positive_map, entities + + def get_tokens_positive_and_prompts_chunked( + self, + original_caption: Union[list, tuple], + enhanced_text_prompts: Optional[ConfigType] = None, + ): + chunked_size = self.test_cfg.get("chunked_size", -1) + original_caption = [clean_label_name(i) for i in original_caption] + + original_caption_chunked = chunks(original_caption, chunked_size) + ids_chunked = chunks(list(range(1, len(original_caption) 
+ 1)), chunked_size) + + positive_map_label_to_token_chunked = [] + caption_string_chunked = [] + positive_map_chunked = [] + entities_chunked = [] + + for i in range(len(ids_chunked)): + if enhanced_text_prompts is not None: + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption_chunked[i], enhanced_text_prompts + ) + else: + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption_chunked[i] + ) + tokenized = self.language_model.tokenizer( + [caption_string], return_tensors="pt" + ) + if tokenized.input_ids.shape[1] > self.language_model.max_tokens: + warnings.warn( + "Inputting a text that is too long will result " + "in poor prediction performance. " + "Please reduce the --chunked-size." + ) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive + ) + + caption_string_chunked.append(caption_string) + positive_map_label_to_token_chunked.append(positive_map_label_to_token) + positive_map_chunked.append(positive_map) + entities_chunked.append(original_caption_chunked[i]) + + return ( + positive_map_label_to_token_chunked, + caption_string_chunked, + positive_map_chunked, + entities_chunked, + ) + + def forward_transformer( + self, + img_feats: Tuple[Tensor], + text_dict: Dict, + batch_data_samples: OptSampleList = None, + ) -> Dict: + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples + ) + + encoder_outputs_dict = self.forward_encoder( + **encoder_inputs_dict, text_dict=text_dict + ) + + tmp_dec_in, head_inputs_dict = self.pre_decoder( + **encoder_outputs_dict, batch_data_samples=batch_data_samples + ) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def forward_encoder( + self, + feat: Tensor, + feat_mask: Tensor, + feat_pos: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + text_dict: Dict, + ) -> Dict: + text_token_mask = text_dict["text_token_mask"] + memory, memory_text = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + # for text encoder + memory_text=text_dict["embedded"], + text_attention_mask=~text_token_mask, + position_ids=text_dict["position_ids"], + text_self_attention_masks=text_dict["masks"], + ) + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes, + memory_text=memory_text, + text_token_mask=text_token_mask, + ) + return encoder_outputs_dict + + def pre_decoder( + self, + memory: Tensor, + memory_mask: Tensor, + spatial_shapes: Tensor, + memory_text: Tensor, + text_token_mask: Tensor, + batch_data_samples: OptSampleList = None, + ) -> Tuple[Dict]: + bs, _, c = memory.shape + + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes + ) + + enc_outputs_class = self.bbox_head.cls_branches[self.decoder.num_layers]( + output_memory, memory_text, text_token_mask + ) + cls_out_features = self.bbox_head.cls_branches[ + self.decoder.num_layers + ].max_text_len + enc_outputs_coord_unact = ( + self.bbox_head.reg_branches[self.decoder.num_layers](output_memory) + + output_proposals + ) + + # NOTE The DINO selects top-k proposals according to scores of + # multi-class classification, while DeformDETR, where the input 
+ # is `enc_outputs_class[..., 0]` selects according to scores of + # binary classification. + topk_indices = torch.topk( + enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1 + )[1] + + topk_score = torch.gather( + enc_outputs_class, + 1, + topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features), + ) + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4) + ) + topk_coords = topk_coords_unact.sigmoid() + topk_coords_unact = topk_coords_unact.detach() + + query = self.query_embedding.weight[:, None, :] + query = query.repeat(1, bs, 1).transpose(0, 1) + if self.training: + dn_label_query, dn_bbox_query, dn_mask, dn_meta = self.dn_query_generator( + batch_data_samples + ) + query = torch.cat([dn_label_query, query], dim=1) + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], dim=1) + else: + reference_points = topk_coords_unact + dn_mask, dn_meta = None, None + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, + memory=memory, + reference_points=reference_points, + dn_mask=dn_mask, + memory_text=memory_text, + text_attention_mask=~text_token_mask, + ) + # NOTE DINO calculates encoder losses on scores and coordinates + # of selected top-k encoder queries, while DeformDETR is of all + # encoder queries. + head_inputs_dict = ( + dict( + enc_outputs_class=topk_score, + enc_outputs_coord=topk_coords, + dn_meta=dn_meta, + ) + if self.training + else dict() + ) + # append text_feats to head_inputs_dict + head_inputs_dict["memory_text"] = memory_text + head_inputs_dict["text_token_mask"] = text_token_mask + return decoder_inputs_dict, head_inputs_dict + + def loss( + self, batch_inputs: Tensor, batch_data_samples: SampleList + ) -> Union[dict, list]: + text_prompts = [data_samples.text for data_samples in batch_data_samples] + + gt_labels = [ + data_samples.gt_instances.labels for data_samples in batch_data_samples + ] + + if "tokens_positive" in batch_data_samples[0]: + tokens_positive = [ + data_samples.tokens_positive for data_samples in batch_data_samples + ] + positive_maps = [] + for token_positive, text_prompt, gt_label in zip( + tokens_positive, text_prompts, gt_labels + ): + tokenized = self.language_model.tokenizer( + [text_prompt], + padding="max_length" + if self.language_model.pad_to_max + else "longest", + return_tensors="pt", + ) + new_tokens_positive = [ + token_positive[label.item()] for label in gt_label + ] + _, positive_map = self.get_positive_map(tokenized, new_tokens_positive) + positive_maps.append(positive_map) + new_text_prompts = text_prompts + else: + new_text_prompts = [] + positive_maps = [] + if len(set(text_prompts)) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. 
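+ # `get_tokens_and_prompts` below returns the character span of every class
+ # name in the built caption; `new_tokens_positive` then picks the span
+ # belonging to each GT box via its label, and `get_positive_map` converts
+ # those spans into a per-box map over caption tokens (roughly of shape
+ # (num_gt, max_text_len)) that supervises the head instead of integer
+ # class ids.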
+ ( + tokenized, + caption_string, + tokens_positive, + _, + ) = self.get_tokens_and_prompts(text_prompts[0], True) + new_text_prompts = [caption_string] * len(batch_inputs) + for gt_label in gt_labels: + new_tokens_positive = [tokens_positive[label] for label in gt_label] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive + ) + positive_maps.append(positive_map) + else: + for text_prompt, gt_label in zip(text_prompts, gt_labels): + ( + tokenized, + caption_string, + tokens_positive, + _, + ) = self.get_tokens_and_prompts(text_prompt, True) + new_tokens_positive = [tokens_positive[label] for label in gt_label] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive + ) + positive_maps.append(positive_map) + new_text_prompts.append(caption_string) + + text_dict = self.language_model(new_text_prompts) + if self.text_feat_map is not None: + text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) + + for i, data_samples in enumerate(batch_data_samples): + positive_map = positive_maps[i].to(batch_inputs.device).bool().float() + text_token_mask = text_dict["text_token_mask"][i] + data_samples.gt_instances.positive_maps = positive_map + data_samples.gt_instances.text_token_mask = text_token_mask.unsqueeze( + 0 + ).repeat(len(positive_map), 1) + if self.use_autocast: + with autocast(enabled=True): + visual_features = self.extract_feat(batch_inputs) + else: + visual_features = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer( + visual_features, text_dict, batch_data_samples + ) + + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples + ) + return losses + + def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): + text_prompts = [] + enhanced_text_prompts = [] + tokens_positives = [] + for data_samples in batch_data_samples: + text_prompts.append(data_samples.text) + if "caption_prompt" in data_samples: + enhanced_text_prompts.append(data_samples.caption_prompt) + else: + enhanced_text_prompts.append(None) + tokens_positives.append(data_samples.get("tokens_positive", None)) + + if "custom_entities" in batch_data_samples[0]: + # Assuming that the `custom_entities` flag + # inside a batch is always the same. For single image inference + custom_entities = batch_data_samples[0].custom_entities + else: + custom_entities = False + if len(text_prompts) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. 
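+ # `get_tokens_positive_and_prompts` is evaluated once and broadcast to the
+ # whole batch; its `entities` output (the parsed class names or noun
+ # phrases) is what the predicted integer labels are mapped back to as
+ # `label_names` at the end of this method.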
+ _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts( + text_prompts[0], + custom_entities, + enhanced_text_prompts[0], + tokens_positives[0], + ) + ] * len(batch_inputs) + else: + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts( + text_prompt, custom_entities, enhanced_text_prompt, tokens_positive + ) + for text_prompt, enhanced_text_prompt, tokens_positive in zip( + text_prompts, enhanced_text_prompts, tokens_positives + ) + ] + token_positive_maps, text_prompts, _, entities = zip( + *_positive_maps_and_prompts + ) + + # image feature extraction + visual_feats = self.extract_feat(batch_inputs) + + if isinstance(text_prompts[0], list): + # chunked text prompts, only bs=1 is supported + assert len(batch_inputs) == 1 + count = 0 + results_list = [] + + entities = [[item for lst in entities[0] for item in lst]] + + for b in range(len(text_prompts[0])): + text_prompts_once = [text_prompts[0][b]] + token_positive_maps_once = token_positive_maps[0][b] + text_dict = self.language_model(text_prompts_once) + # text feature map layer + if self.text_feat_map is not None: + text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) + + batch_data_samples[0].token_positive_map = token_positive_maps_once + + head_inputs_dict = self.forward_transformer( + copy.deepcopy(visual_feats), text_dict, batch_data_samples + ) + pred_instances = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples, + )[0] + + if len(pred_instances) > 0: + pred_instances.labels += count + count += len(token_positive_maps_once) + results_list.append(pred_instances) + results_list = [results_list[0].cat(results_list)] + is_rec_tasks = [False] * len(results_list) + else: + # extract text feats + text_dict = self.language_model(list(text_prompts)) + # text feature map layer + if self.text_feat_map is not None: + text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) + + is_rec_tasks = [] + for i, data_samples in enumerate(batch_data_samples): + if token_positive_maps[i] is not None: + is_rec_tasks.append(False) + else: + is_rec_tasks.append(True) + data_samples.token_positive_map = token_positive_maps[i] + + head_inputs_dict = self.forward_transformer( + visual_feats, text_dict, batch_data_samples + ) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples, + ) + + for data_sample, pred_instances, entity, is_rec_task in zip( + batch_data_samples, results_list, entities, is_rec_tasks + ): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + if is_rec_task: + label_names.append(entity) + continue + if labels >= len(entity): + warnings.warn( + "The unexpected output indicates an issue with " + "named entity recognition. You can try " + "setting custom_entities=True and running " + "again to see if it helps." 
+ ) + label_names.append("unobject") + else: + label_names.append(entity[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples diff --git a/masa/models/detectors/sam_masa.py b/masa/models/detectors/sam_masa.py new file mode 100644 index 0000000000000000000000000000000000000000..1cfd0a6708b6578db86ee7c93f35c6452708273b --- /dev/null +++ b/masa/models/detectors/sam_masa.py @@ -0,0 +1,191 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +from typing import Any, Dict, List, Tuple + +import torch +from mmdet.registry import MODELS +from mmengine.model import BaseModule +from torch.nn import functional as F + +from ..sam.image_encoder import ImageEncoderViT +from ..sam.mask_decoder import MaskDecoder +from ..sam.prompt_encoder import PromptEncoder + + +@MODELS.register_module() +class SamMasa(BaseModule): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + backbone: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + **kwargs + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + backbone (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__(**kwargs) + if backbone is not None: + self.backbone = MODELS.build(backbone) + + if prompt_encoder is not None: + self.prompt_encoder = MODELS.build(prompt_encoder) + else: + self.prompt_encoder = None + + if mask_decoder is not None: + self.mask_decoder = MODELS.build(mask_decoder) + else: + self.mask_decoder = None + + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False + ) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, batched_input: List[Dict[str, Any]], multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. 
+ multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack( + [self.preprocess(x["image"]) for x in batched_input], dim=0 + ) + image_embeddings = self.backbone(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + masks = F.interpolate( + masks, + (self.backbone.img_size, self.backbone.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate( + masks, original_size, mode="bilinear", align_corners=False + ) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.backbone.img_size - h + padw = self.backbone.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/masa/models/losses/__init__.py b/masa/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1eaeff39dfecb635ce001c5530f36e1159b1c4d --- /dev/null +++ b/masa/models/losses/__init__.py @@ -0,0 +1,3 @@ +from .unbiased_contrastive_loss import UnbiasedContrastLoss + +__all__ = ["UnbiasedContrastLoss"] diff --git a/masa/models/losses/__pycache__/__init__.cpython-311.pyc b/masa/models/losses/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74a117a52e71a6659be38bcb218183af8c397cb2 Binary files /dev/null and b/masa/models/losses/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/losses/__pycache__/unbiased_contrastive_loss.cpython-311.pyc b/masa/models/losses/__pycache__/unbiased_contrastive_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27f2bd89008592dfb9ce15c5f67b6cda34289fa7 Binary files /dev/null and b/masa/models/losses/__pycache__/unbiased_contrastive_loss.cpython-311.pyc differ diff --git a/masa/models/losses/unbiased_contrastive_loss.py b/masa/models/losses/unbiased_contrastive_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d416aaaea122f88ec554b1a8910a72bc32a8443e --- /dev/null +++ b/masa/models/losses/unbiased_contrastive_loss.py @@ -0,0 +1,76 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +import torch +import torch.nn as nn +from mmdet.models import weight_reduce_loss +from mmdet.registry import MODELS + + +def multi_pos_cross_entropy( + pred, label, weight=None, reduction="mean", avg_factor=None, pos_normalize=True, +): + + valid_mask = label.sum(1) != 0 + pred = pred[valid_mask] + label = label[valid_mask] + weight = weight[valid_mask] + if min(pred.shape) != 0: + logits_max, _ = torch.max(pred, dim=1, keepdim=True) + logits = pred - logits_max.detach() + else: + logits = pred + + if pos_normalize: + pos_norm = torch.div(label, label.sum(1).reshape(-1, 1)) + exp_logits = (torch.exp(logits)) * pos_norm + ( + torch.exp(logits) + ) * torch.logical_not(label) + else: + exp_logits = torch.exp(logits) + exp_logits_input = exp_logits.sum(1, keepdim=True) + log_prob = logits - torch.log(exp_logits_input) + + mean_log_prob_pos = (label * log_prob).sum(1) / label.sum(1) + loss = -mean_log_prob_pos + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor + ) + + return loss + + +@MODELS.register_module() +class UnbiasedContrastLoss(nn.Module): + def __init__(self, reduction="mean", loss_weight=1.0): + super(UnbiasedContrastLoss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward( + self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs + ): + 
assert cls_score.size() == label.size() + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + loss_cls = self.loss_weight * multi_pos_cross_entropy( + cls_score, + label, + weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs + ) + return loss_cls diff --git a/masa/models/mot/__init__.py b/masa/models/mot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a87ceeefcd95eb9e47a0b569e2acd13996fe747 --- /dev/null +++ b/masa/models/mot/__init__.py @@ -0,0 +1,3 @@ +from .masa import MASA + +__all__ = ["MASA"] diff --git a/masa/models/mot/__pycache__/__init__.cpython-311.pyc b/masa/models/mot/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab494cafe35f86550a56b7c4b846ea7110437746 Binary files /dev/null and b/masa/models/mot/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/mot/__pycache__/masa.cpython-311.pyc b/masa/models/mot/__pycache__/masa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c387821ab28b1df42c71f6067d85ac7af7dac009 Binary files /dev/null and b/masa/models/mot/__pycache__/masa.cpython-311.pyc differ diff --git a/masa/models/mot/masa.py b/masa/models/mot/masa.py new file mode 100644 index 0000000000000000000000000000000000000000..02d4061b7cf0d180bb76b465111a30848837a93c --- /dev/null +++ b/masa/models/mot/masa.py @@ -0,0 +1,498 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +import copy +import os +import pickle +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import torch +from mmdet.models.mot.base import BaseMOTModel +from mmdet.registry import MODELS +from mmdet.structures import TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig +from mmengine.structures import InstanceData +from torch import Tensor + + +@MODELS.register_module() +class MASA(BaseMOTModel): + + """Matching Anything By Segmenting Anything. + + This multi object tracker is the implementation of `MASA + https://arxiv.org/abs/2406.04221`. + + Args: + backbone (dict, optional): Configuration of backbone. Defaults to None. + detector (dict, optional): Configuration of detector. Defaults to None. + masa_adapter (dict, optional): Configuration of MASA adapter. Defaults to None. + rpn_head (dict, optional): Configuration of RPN head. Defaults to None. + roi_head (dict, optional): Configuration of RoI head. Defaults to None. + track_head (dict, optional): Configuration of track head. Defaults to None. + tracker (dict, optional): Configuration of tracker. Defaults to None. + freeze_detector (bool): If True, freeze the detector weights. Defaults to False. + freeze_masa_backbone (bool): If True, freeze the MASA backbone weights. Defaults to False. + freeze_masa_adapter (bool): If True, freeze the MASA adapter weights. Defaults to False. + freeze_object_prior_distillation (bool): If True, freeze the object prior distillation. Defaults to False. + data_preprocessor (dict or ConfigDict, optional): The pre-process config of :class:`TrackDataPreprocessor`. + It usually includes, ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. Defaults to None. + train_cfg (dict or ConfigDict, optional): Training configuration. Defaults to None. + test_cfg (dict or ConfigDict, optional): Testing configuration. Defaults to None. + init_cfg (dict or list[dict], optional): Configuration of initialization. Defaults to None. 
+ load_public_dets (bool): If True, load public detections. Defaults to False. + public_det_path (str, optional): Path to public detections. Required if load_public_dets is True. Defaults to None. + given_dets (bool): If True, detections are given. Defaults to False. + with_segm (bool): If True, segmentation masks are included. Defaults to False. + end_pkl_name (str): Suffix for pickle file names. Defaults to '.pth'. + unified_backbone (bool): If True, use a unified backbone. Defaults to False. + use_masa_backbone (bool): If True, use the MASA backbone. Defaults to False. + benchmark (str): Benchmark for evaluation. Defaults to 'tao'. + """ + + def __init__( + self, + backbone: Optional[dict] = None, + detector: Optional[dict] = None, + masa_adapter: Optional[dict] = None, + rpn_head: Optional[dict] = None, + roi_head: Optional[dict] = None, + track_head: Optional[dict] = None, + tracker: Optional[dict] = None, + freeze_detector: bool = False, + freeze_masa_backbone: bool = False, + freeze_masa_adapter: bool = False, + freeze_object_prior_distillation: bool = False, + data_preprocessor: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + load_public_dets=False, + public_det_path=None, + given_dets=False, + with_segm=False, + end_pkl_name=".pth", + unified_backbone=False, + use_masa_backbone=False, + benchmark="tao", + ) -> None: + super().__init__(data_preprocessor, init_cfg) + + self.use_masa_backbone = use_masa_backbone + if use_masa_backbone: + assert ( + backbone is not None + ), "backbone must be set when using MASA backbone." + + if backbone is not None: + self.backbone = MODELS.build(backbone) + + if detector is not None: + self.detector = MODELS.build(detector) + + if masa_adapter is not None: + self.masa_adapter = MODELS.build(masa_adapter) + + if rpn_head is not None: + rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None + rpn_head_ = rpn_head.copy() + rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) + rpn_head_num_classes = rpn_head_.get("num_classes", None) + if rpn_head_num_classes is None: + rpn_head_.update(num_classes=1) + else: + if rpn_head_num_classes != 1: + warnings.warn( + "The `num_classes` should be 1 in RPN, but get " + f"{rpn_head_num_classes}, please set " + "rpn_head.num_classes = 1 in your config file." + ) + rpn_head_.update(num_classes=1) + self.rpn_head = MODELS.build(rpn_head_) + + if roi_head is not None: + # update train and test cfg here for now + rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None + roi_head.update(train_cfg=rcnn_train_cfg) + roi_head.update(test_cfg=test_cfg.rcnn) + self.roi_head = MODELS.build(roi_head) + + if track_head is not None: + self.track_head = MODELS.build(track_head) + + if tracker is not None: + self.tracker = MODELS.build(tracker) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.freeze_detector = freeze_detector + self.freeze_masa_adapter = freeze_masa_adapter + self.freeze_object_prior_distillation = freeze_object_prior_distillation + self.freeze_masa_backbone = freeze_masa_backbone + + def set_to_eval(module, input): + module.eval() + + if self.freeze_detector: + assert ( + detector is not None + ), "detector must be set when freeze_detector is True." + self.freeze_module("detector") + # self.detector.backbone.register_forward_pre_hook(set_to_eval) + + if self.freeze_masa_adapter: + assert ( + masa_adapter is not None + ), "masa_adapter must be set when freeze_masa_adapter is True." 
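+ # `freeze_module` below stops gradient updates for the adapter, and the
+ # `set_to_eval` forward pre-hook keeps it in eval mode even while the rest
+ # of the model trains, so its normalization layers do not update running
+ # statistics; the same pattern is applied to the frozen backbone further
+ # down.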
+ self.freeze_module("masa_adapter") + + self.masa_adapter.register_forward_pre_hook(set_to_eval) + + if self.freeze_object_prior_distillation: + assert ( + roi_head is not None + ), "roi_head must be set when freeze_object_prior_distillation is True." + assert ( + rpn_head is not None + ), "rpn_head must be set when freeze_object_prior_distillation is True." + self.freeze_module("roi_head") + self.freeze_module("rpn_head") + + if self.freeze_masa_backbone: + assert ( + backbone is not None + ), "backbone must be set when freeze_masa_backbone is True." + self.freeze_module("backbone") + self.backbone.register_forward_pre_hook(set_to_eval) + + if load_public_dets: + assert ( + public_det_path is not None + ), "load_public_dets and public_det_path must be set together." + self.benchmark = benchmark + self.load_public_dets = load_public_dets + self.public_det_path = public_det_path + self.with_segm = with_segm + self.end_pkl_name = end_pkl_name + self.given_dets = given_dets + + self.unified_backbone = unified_backbone + + @property + def with_rpn(self) -> bool: + """bool: whether the detector has RPN""" + return hasattr(self, "rpn_head") and self.rpn_head is not None + + @property + def with_roi_head(self) -> bool: + """bool: whether the detector has a RoI head""" + return hasattr(self, "roi_head") and self.roi_head is not None + + def predict( + self, + inputs: Tensor, + data_samples: TrackSampleList, + rescale: bool = True, + **kwargs, + ) -> TrackSampleList: + """Predict results from a video and data samples with post- processing. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of frames in a video. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + TrackSampleList: Tracking results of the inputs. + """ + assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)." + assert ( + inputs.size(0) == 1 + ), "MASA inference only support 1 batch size per gpu for now." + + assert len(data_samples) == 1, "MASA only support 1 batch size per gpu for now." 
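A quick sketch of the shape contract those predict() assertions enforce; the tensor sizes are arbitrary:

import torch

video = torch.randn(1, 8, 3, 480, 640)        # (N=1, T=8, C, H, W): one clip, batch size 1
assert video.dim() == 5 and video.size(0) == 1
frame0 = video[:, 0].contiguous()              # each frame is handled as a (1, C, H, W) image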
+ + track_data_sample = data_samples[0] + video_len = len(track_data_sample) + if track_data_sample[0].frame_id == 0: + self.tracker.reset() + + for frame_id in range(video_len): + img_data_sample = track_data_sample[frame_id] + single_img = inputs[:, frame_id].contiguous() + if self.load_public_dets: + img_name = img_data_sample.img_path + if img_name is not None: + if self.benchmark == "bdd": + pickle_name = img_name.replace( + "data/bdd/bdd100k/images/track/val/", "" + ).replace(".jpg", self.end_pkl_name) + elif self.benchmark == "tao": + pickle_name = img_name.replace("data/tao/frames/", "").replace( + ".jpg", self.end_pkl_name + ) + + path = os.path.join(self.public_det_path, pickle_name) + pickle_res = pickle.load(open(path, "rb")) + det_labels = torch.tensor(pickle_res["det_labels"]).to("cuda") + det_bboxes = ( + torch.tensor(pickle_res["det_bboxes"]).to("cuda").to(torch.float32) + ) + if len(det_bboxes) != 0: + if det_bboxes.size(1) == 4: + det_bboxes = torch.cat( + [ + det_bboxes, + torch.ones(det_bboxes.size(0), 1).to(det_bboxes.device), + ], + dim=1, + ) + + det_results = InstanceData() + det_results.labels = det_labels + det_results.bboxes = det_bboxes[:, :4] + det_results.scores = det_bboxes[:, 4] + + if self.with_segm: + segm_results = pickle_res["det_masks"] + det_results.masks = segm_results + + img_data_sample.pred_instances = det_results + + if self.unified_backbone: + if hasattr(self.detector.backbone, "with_text_model"): + x = self.detector.backbone.forward_image(single_img) + elif self.detector.__class__.__name__ == "SamMasa": + x = self.detector.backbone.forward_base_multi_level(single_img) + else: + x = self.detector.backbone(single_img) + elif self.use_masa_backbone: + x = self.backbone.forward(single_img) + x_m = self.masa_adapter(x) + + elif self.given_dets: + assert ( + "det_bboxes" in img_data_sample + ), "det_bboxes must be given when given_dets is True." + assert ( + "det_labels" in img_data_sample + ), "det_labels must be given when given_dets is True." 
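Both the public-detection and given-detection branches normalise boxes to five columns (x1, y1, x2, y2, score) before packing them into InstanceData. A minimal sketch of that convention, with random boxes standing in for real detections:

import torch
from mmengine.structures import InstanceData

det_bboxes = torch.rand(3, 4)                                # detections without scores
if det_bboxes.size(1) == 4:                                  # pad a unit score column
    det_bboxes = torch.cat([det_bboxes, torch.ones(3, 1)], dim=1)

det_results = InstanceData(
    bboxes=det_bboxes[:, :4],
    scores=det_bboxes[:, 4],
    labels=torch.zeros(3, dtype=torch.long),
)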
+ det_labels = img_data_sample.det_labels + det_bboxes = img_data_sample.det_bboxes + if len(det_bboxes) != 0: + if det_bboxes.size(1) == 4: + det_bboxes = torch.cat( + [ + det_bboxes, + torch.ones(det_bboxes.size(0), 1).to(det_bboxes.device), + ], + dim=1, + ) + det_results = InstanceData() + det_results.labels = det_labels + det_results.bboxes = det_bboxes[:, :4] + det_results.scores = det_bboxes[:, 4] + + img_data_sample.pred_instances = det_results + + if self.unified_backbone: + if hasattr(self.detector.backbone, "with_text_model"): + x = self.detector.backbone.forward_image(single_img) + elif self.detector.__class__.__name__ == "SamMasa": + x = self.detector.backbone.forward_base_multi_level(single_img) + else: + x = self.detector.backbone(single_img) + elif self.use_masa_backbone: + x = self.backbone.forward(single_img) + x_m = self.masa_adapter(x) + else: + if self.unified_backbone: + if hasattr(self.detector.backbone, "with_text_model"): + texts = img_data_sample.texts + ## fix some inconsistency caused by the implementation of yolo-world and mmdet + if type(texts[0]) == list: + new_texts = [text[0] for text in texts] + del img_data_sample.texts + img_data_sample.set_field( + new_texts, "texts", field_type="metainfo" + ) + ( + backbone_feats, + img_feats, + text_feats, + ) = self.detector.extract_feat(single_img, [img_data_sample]) + x_m = self.masa_adapter(backbone_feats) + img_data_sample = self.detector.predict( + single_img, + (img_feats, text_feats), + [img_data_sample], + rescale=rescale, + )[0] + else: + x = self.detector.backbone(single_img) + x_m = self.masa_adapter(x) + if self.detector.with_neck: + x = self.detector.neck(x) + + img_data_sample = self.detector.predict( + single_img, x, [img_data_sample], rescale=rescale + )[0] + else: + raise NotImplementedError + + frame_pred_track_instances = self.tracker.track( + model=self, + img=single_img, + feats=x_m, + data_sample=img_data_sample, + with_segm=self.with_segm, + **kwargs, + ) + if self.with_segm: + if frame_pred_track_instances.mask_inds is not None: + frame_pred_track_instances.masks = [ + img_data_sample.pred_instances.masks[i] + for i in frame_pred_track_instances.mask_inds + ] + + img_data_sample.pred_track_instances = frame_pred_track_instances + + return [track_data_sample] + + def parse_tensors(self, tensor_tuple, key_ids, ref_ids): + key_tensors = [] + ref_tensors = [] + device = tensor_tuple[0].device + for tensor in tensor_tuple: + key_tensors.append( + tensor.index_select( + 0, torch.tensor(key_ids, dtype=torch.long, device=device) + ) + ) + ref_tensors.append( + tensor.index_select( + 0, torch.tensor(ref_ids, dtype=torch.long, device=device) + ) + ) + + return list(key_tensors), list(ref_tensors) + + def loss( + self, inputs: Tensor, data_samples: TrackSampleList, **kwargs + ) -> Union[dict, tuple]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding + input images. Typically these should be mean centered and std + scaled. The N denotes batch size. The T denotes the number of + frames. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + + Returns: + dict: A dictionary of loss components. + """ + # modify the inputs shape to fit mmdet + assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)." + assert ( + inputs.size(1) == 2 + ), "MASA can only have 1 key frame and 1 reference frame." 
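The loss path consumes exactly one key frame and one reference frame per clip; per-sample frame indices are then used to gather two ordinary 4-D image batches out of the 5-D input, as the code below this point does. A small stand-alone sketch of that gather for a batch of four 2-frame clips:

import torch

inputs = torch.randn(4, 2, 3, 640, 640)        # (N, T=2, C, H, W)
key_frame_inds = torch.tensor([0, 1, 0, 1])    # per-sample key frame index
ref_frame_inds = 1 - key_frame_inds            # the remaining frame is the reference
batch_inds = torch.arange(4)
key_imgs = inputs[batch_inds, key_frame_inds]  # (4, 3, 640, 640)
ref_imgs = inputs[batch_inds, ref_frame_inds]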
+ if self.detector is not None: + self.detector.eval() + # split the data_samples into two aspects: key frames and reference + # frames + ref_data_samples, key_data_samples = [], [] + key_frame_inds, ref_frame_inds = [], [] + # set cat_id of gt_labels to 0 in RPN + for track_data_sample in data_samples: + key_frame_inds.append(track_data_sample.key_frames_inds[0]) + ref_frame_inds.append(track_data_sample.ref_frames_inds[0]) + key_data_sample = track_data_sample.get_key_frames()[0] + key_data_sample.gt_instances.labels = torch.zeros_like( + key_data_sample.gt_instances.labels + ) + key_data_samples.append(key_data_sample) + ref_data_sample = track_data_sample.get_ref_frames()[0] + ref_data_samples.append(ref_data_sample) + + key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64) + ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64) + batch_inds = torch.arange(len(inputs)) + key_imgs = inputs[batch_inds, key_frame_inds].contiguous() + ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous() + + if self.use_masa_backbone: + x = self.backbone.forward(key_imgs) + ref_x = self.backbone.forward(ref_imgs) + + else: + if hasattr(self.detector.backbone, "with_text_model"): + x = self.detector.backbone.forward_image(key_imgs) + ref_x = self.detector.backbone.forward_image(ref_imgs) + elif self.detector.__class__.__name__ == "SamMasa": + x = self.detector.backbone.forward_base_multi_level(key_imgs) + ref_x = self.detector.backbone.forward_base_multi_level(ref_imgs) + else: + x = self.detector.backbone.forward(key_imgs) + ref_x = self.detector.backbone.forward(ref_imgs) + + x_m = self.masa_adapter(x) + ref_x_m = self.masa_adapter(ref_x) + + losses = dict() + + if self.with_rpn: + proposal_cfg = self.train_cfg.get("rpn_proposal", self.test_cfg.rpn) + key_rpn_data_samples = copy.deepcopy(key_data_samples) + ref_rpn_data_samples = copy.deepcopy(ref_data_samples) + # set cat_id of gt_labels to 0 in RPN + for data_sample in key_rpn_data_samples: + data_sample.gt_instances.labels = torch.zeros_like( + data_sample.gt_instances.labels + ) + for data_sample in ref_rpn_data_samples: + data_sample.gt_instances.labels = torch.zeros_like( + data_sample.gt_instances.labels + ) + + rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( + x_m, key_rpn_data_samples, proposal_cfg=proposal_cfg, **kwargs + ) + ref_rpn_results_list = self.rpn_head.predict( + ref_x_m, ref_rpn_data_samples, **kwargs + ) + + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in keys: + if "loss" in key and "rpn" not in key: + rpn_losses[f"rpn_{key}"] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + raise NotImplementedError("MASA only support with_rpn for now.") + + # roi_head loss + losses_detect = self.roi_head.loss( + x_m, rpn_results_list, key_data_samples, **kwargs + ) + losses.update(losses_detect) + + # tracking head loss + losses_track = self.track_head.loss( + x_m, ref_x_m, rpn_results_list, ref_rpn_results_list, data_samples, **kwargs + ) + losses.update(losses_track) + + return losses diff --git a/masa/models/necks/__init__.py b/masa/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9aae177888b9eb9f64b08602ed4b2edae8d40dbf --- /dev/null +++ b/masa/models/necks/__init__.py @@ -0,0 +1,4 @@ +from .deform_fusion import DeformFusion +from .simplefpn import SimpleFPN + +__all__ = ["DeformFusion", "SimpleFPN"] diff --git a/masa/models/necks/__pycache__/__init__.cpython-311.pyc 
b/masa/models/necks/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748ea769a4a19806a2ed7d42d6ca37a3154c5f80 Binary files /dev/null and b/masa/models/necks/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/necks/__pycache__/deform_fusion.cpython-311.pyc b/masa/models/necks/__pycache__/deform_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..201a2c11338405800b76f92212dbcb9b253608cc Binary files /dev/null and b/masa/models/necks/__pycache__/deform_fusion.cpython-311.pyc differ diff --git a/masa/models/necks/__pycache__/simplefpn.cpython-311.pyc b/masa/models/necks/__pycache__/simplefpn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5def65eba3756b379682b87aafa280df7be9df2e Binary files /dev/null and b/masa/models/necks/__pycache__/simplefpn.cpython-311.pyc differ diff --git a/masa/models/necks/deform_fusion.py b/masa/models/necks/deform_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c9052443fca2f72ee9dd152a55a4841185fd8b36 --- /dev/null +++ b/masa/models/necks/deform_fusion.py @@ -0,0 +1,192 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from mmcv.cnn import build_norm_layer +from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d +from mmdet.registry import MODELS +from mmengine.model import BaseModule, constant_init, normal_init + +# Reference: +# https://github.com/microsoft/DynamicHead +# https://github.com/jshilong/SEPC + + +class LayerNormProxy(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = rearrange(x, "b c h w -> b h w c") + x = self.norm(x) + return rearrange(x, "b h w c -> b c h w") + + +class DyDCNv2(nn.Module): + """ModulatedDeformConv2d with normalization layer used in DyHead. + + This module cannot be configured with `conv_cfg=dict(type='DCNv2')` + because DyHead calculates offset and mask from middle-level feature. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int | tuple[int], optional): Stride of the convolution. + Default: 1. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='GN', num_groups=16, requires_grad=True). + """ + + def __init__( + self, + in_channels, + out_channels, + stride=1, + norm_cfg=dict(type="GN", num_groups=16, requires_grad=True), + ): + super().__init__() + self.with_norm = norm_cfg is not None + bias = not self.with_norm + self.conv = ModulatedDeformConv2d( + in_channels, out_channels, 3, stride=stride, padding=1, bias=bias + ) + if self.with_norm: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + + def forward(self, x, offset, mask): + """Forward function.""" + x = self.conv(x.contiguous(), offset, mask) + if self.with_norm: + x = self.norm(x) + return x + + +class DyHeadBlock(nn.Module): + """Modified DyHead Block for dynamic feature fusion. + We remove the task and scale aware attention in the original implementation. + + HSigmoid arguments in default act_cfg follow official code, not paper. + https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + zero_init_offset (bool, optional): Whether to use zero init for + `spatial_conv_offset`. 
Default: True. + """ + + def __init__( + self, in_channels, out_channels, zero_init_offset=True, fix_upsample=False, + ): + super().__init__() + self.zero_init_offset = zero_init_offset + self.fix_upsample = fix_upsample + # (offset_x, offset_y, mask) * kernel_size_y * kernel_size_x + self.offset_and_mask_dim = 3 * 3 * 3 * 3 + self.offset_dim = 3 * 2 * 3 * 3 + + self.spatial_conv_offset = nn.Conv2d( + in_channels, self.offset_and_mask_dim, 3, padding=1 + ) + self.spatial_conv_high = DyDCNv2(in_channels, out_channels) + self.spatial_conv_mid = DyDCNv2(in_channels, out_channels) + self.spatial_conv_low = DyDCNv2(in_channels, out_channels, stride=2) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, 0, 0.01) + if self.zero_init_offset: + constant_init(self.spatial_conv_offset, 0) + + def forward(self, x): + """Forward function.""" + outs = [] + for level in range(len(x)): + offset_and_mask = self.spatial_conv_offset(x[level]) + offset = offset_and_mask[:, : self.offset_dim, :, :] + mask = offset_and_mask[:, self.offset_dim :, :, :].sigmoid() + + # calculate offset and mask of DCNv2 from current feature + offsets = offset.split(offset.size(1) // 3, dim=1) + masks = mask.split(mask.size(1) // 3, dim=1) + + sum_feat = self.spatial_conv_mid(x[level], offsets[0], masks[0]) + summed_levels = 1 + if level > 0: + sum_feat += self.spatial_conv_low(x[level - 1], offsets[1], masks[1]) + summed_levels += 1 + if level < len(x) - 1: + if not self.fix_upsample: + # this upsample order is weird, but faster than natural order + # https://github.com/microsoft/DynamicHead/issues/25 + sum_feat += F.interpolate( + self.spatial_conv_high(x[level + 1], offsets[2], masks[2]), + size=x[level].shape[-2:], + mode="bilinear", + align_corners=True, + ) + else: + sum_feat += self.spatial_conv_high( + F.interpolate( + x[level + 1], + size=x[level].shape[-2:], + mode="bilinear", + align_corners=True, + ), + offsets[2], + masks[2], + ) + summed_levels += 1 + outs.append(sum_feat / summed_levels) + + return outs + +@MODELS.register_module() +class DeformFusion(BaseModule): + """Deformable Fusion Module for MASA.""" + + def __init__( + self, + in_channels, + out_channels, + num_blocks=6, + zero_init_offset=True, + fix_upsample=False, + init_cfg=None, + ): + assert init_cfg is None, ( + "To prevent abnormal initialization " + "behavior, init_cfg is not allowed to be set" + ) + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.zero_init_offset = zero_init_offset + + dyhead_blocks = [] + for i in range(num_blocks): + in_channels = self.in_channels if i == 0 else self.out_channels + dyhead_blocks.append( + DyHeadBlock( + in_channels, + self.out_channels, + zero_init_offset=zero_init_offset, + fix_upsample=fix_upsample, + ) + ) + self.dyhead_blocks = nn.Sequential(*dyhead_blocks) + + def forward(self, inputs): + """Forward function.""" + assert isinstance(inputs, (tuple, list)) + outs = self.dyhead_blocks(inputs) + return tuple(outs) diff --git a/masa/models/necks/simplefpn.py b/masa/models/necks/simplefpn.py new file mode 100644 index 0000000000000000000000000000000000000000..278ae0953ff31e204124fb3ffc20e9961c530008 --- /dev/null +++ b/masa/models/necks/simplefpn.py @@ -0,0 +1,250 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmdet.registry import MODELS +from mmengine.model import BaseModule + + +class 
Norm2d(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=1e-6) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +@MODELS.register_module() +class SimpleFPN(BaseModule): + r"""Simplified Feature Pyramid Network. + + This is an implementation of Simple FPN used in ViT Det. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__( + self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + use_residual=True, + upsample_cfg=dict(mode="nearest"), + init_cfg=dict(type="Xavier", layer="Conv2d", distribution="uniform"), + ): + super(SimpleFPN, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + self.use_residual = use_residual + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ("on_input", "on_lateral", "on_output") + elif add_extra_convs: # True + self.add_extra_convs = "on_input" + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False, + ) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False, + ) + # l_conv = checkpoint_wrapper(l_conv) + # fpn_conv = checkpoint_wrapper(fpn_conv) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == "on_input": + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False, + ) + self.fpn_convs.append(extra_fpn_conv) + + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d( + self.in_channels[0], self.in_channels[0], kernel_size=2, stride=2 + ), + Norm2d(self.in_channels[0]), + nn.GELU(), + nn.ConvTranspose2d( + self.in_channels[0], self.in_channels[0], kernel_size=2, stride=2 + ), + ) + + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d( + self.in_channels[0], self.in_channels[0], kernel_size=2, stride=2 + ), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, inputs): + """Forward function.""" + features = [] + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + if 
isinstance(inputs, list): + assert len(inputs) == len(ops) + for i in range(len(ops)): + features.append(ops[i](inputs[i])) + else: + for i in range(len(ops)): + features.append(ops[i](inputs)) + + assert len(features) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(features[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + if self.use_residual: + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if "scale_factor" in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg + ) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg + ) + + # build outputs + # part 1: from original levels + outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == "on_input": + extra_source = features[self.backbone_end_level - 1] + elif self.add_extra_convs == "on_lateral": + extra_source = laterals[-1] + elif self.add_extra_convs == "on_output": + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/masa/models/roi_heads/__init__.py b/masa/models/roi_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39fa5d754254f5747f4c65c14706ce68983827e9 --- /dev/null +++ b/masa/models/roi_heads/__init__.py @@ -0,0 +1,3 @@ +from .track_heads import MasaTrackHead + +__all__ = ["MasaTrackHead"] diff --git a/masa/models/roi_heads/__pycache__/__init__.cpython-311.pyc b/masa/models/roi_heads/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34808991853676d7a550ff9ea9e656d32769b975 Binary files /dev/null and b/masa/models/roi_heads/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/roi_heads/track_heads/__init__.py b/masa/models/roi_heads/track_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18aaa4376bdc6507f39467d8900c29e388489f81 --- /dev/null +++ b/masa/models/roi_heads/track_heads/__init__.py @@ -0,0 +1,3 @@ +from .masa_track_head import MasaTrackHead + +__all__ = ["MasaTrackHead"] diff --git a/masa/models/roi_heads/track_heads/__pycache__/__init__.cpython-311.pyc b/masa/models/roi_heads/track_heads/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f3b2ff25ea1660ef1131efc02287f1938eef9b1 Binary files /dev/null and b/masa/models/roi_heads/track_heads/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/masa/models/roi_heads/track_heads/__pycache__/masa_track_head.cpython-311.pyc b/masa/models/roi_heads/track_heads/__pycache__/masa_track_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8a625320c277ee2ddd376d65ea7c4616236db5e Binary files /dev/null and b/masa/models/roi_heads/track_heads/__pycache__/masa_track_head.cpython-311.pyc differ diff --git a/masa/models/roi_heads/track_heads/masa_track_head.py b/masa/models/roi_heads/track_heads/masa_track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..15b8b1843c42921d5d17c4d71f904a9bb5417b2a --- /dev/null +++ b/masa/models/roi_heads/track_heads/masa_track_head.py @@ -0,0 +1,189 @@ +from typing import List, Optional + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import TrackSampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList +from mmengine.model import BaseModule +from torch import Tensor + + +@MODELS.register_module() +class MasaTrackHead(BaseModule): + """The masa track head. This takes the features from masa adapter to produce the final """ + + def __init__( + self, + roi_extractor: Optional[dict] = None, + embed_head: Optional[dict] = None, + regress_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs + ): + super().__init__(init_cfg=init_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if embed_head is not None: + self.init_embed_head(roi_extractor, embed_head) + + if regress_head is not None: + raise NotImplementedError("Regression head is not supported yet.") + + self.init_assigner_sampler() + + def init_embed_head(self, roi_extractor, embed_head) -> None: + """Initialize ``embed_head`` + + Args: + roi_extractor (dict, optional): Configuration of roi extractor. + Defaults to None. + embed_head (dict, optional): Configuration of embed head. Defaults + to None. + """ + self.roi_extractor = MODELS.build(roi_extractor) + self.embed_head = MODELS.build(embed_head) + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler.""" + self.bbox_assigner = None + self.bbox_sampler = None + if self.train_cfg: + self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner) + self.bbox_sampler = TASK_UTILS.build( + self.train_cfg.sampler, default_args=dict(context=self) + ) + + @property + def with_track(self) -> bool: + """bool: whether the multi-object tracker has an embed head""" + return hasattr(self, "embed_head") and self.embed_head is not None + + def extract_roi_feats(self, feats: List[Tensor], bboxes: List[Tensor]) -> Tensor: + """Extract roi features. + + Args: + feats (list[Tensor]): list of multi-level image features. + bboxes (list[Tensor]): list of bboxes in sampling result. + + Returns: + Tensor: The extracted roi features. + """ + rois = bbox2roi(bboxes) + bbox_feats = self.roi_extractor(feats[: self.roi_extractor.num_inputs], rois) + return bbox_feats + + def loss( + self, + key_feats: List[Tensor], + ref_feats: List[Tensor], + rpn_results_list: InstanceList, + ref_rpn_results_list: InstanceList, + data_samples: TrackSampleList, + **kwargs + ) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + key_feats (list[Tensor]): list of multi-level image features. + ref_feats (list[Tensor]): list of multi-level ref_img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals of key img. 
+ ref_rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals of ref img. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict: A dictionary of loss components. + """ + assert self.with_track + num_imgs = len(data_samples) + batch_gt_instances = [] + ref_batch_gt_instances = [] + batch_gt_instances_ignore = [] + gt_match_indices_list = [] + for track_data_sample in data_samples: + key_data_sample = track_data_sample.get_key_frames()[0] + ref_data_sample = track_data_sample.get_ref_frames()[0] + batch_gt_instances.append(key_data_sample.gt_instances) + ref_batch_gt_instances.append(ref_data_sample.gt_instances) + if "ignored_instances" in key_data_sample: + batch_gt_instances_ignore.append(key_data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + # get gt_match_indices + ins_ids = key_data_sample.gt_instances.instances_ids.tolist() + ref_ins_ids = ref_data_sample.gt_instances.instances_ids.tolist() + match_indices = Tensor( + [ + ref_ins_ids.index(i) if (i in ref_ins_ids and i > 0) else -1 + for i in ins_ids + ] + ).to(key_feats[0].device) + gt_match_indices_list.append(match_indices) + + key_sampling_results, ref_sampling_results = [], [] + for i in range(num_imgs): + rpn_results = rpn_results_list[i] + ref_rpn_results = ref_rpn_results_list[i] + # rename ref_rpn_results.bboxes to ref_rpn_results.priors + if "priors" not in rpn_results: + rpn_results.priors = rpn_results.pop("bboxes") + ref_rpn_results.priors = ref_rpn_results.pop("bboxes") + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], batch_gt_instances_ignore[i] + ) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in key_feats], + ) + key_sampling_results.append(sampling_result) + + ref_assign_result = self.bbox_assigner.assign( + ref_rpn_results, ref_batch_gt_instances[i], batch_gt_instances_ignore[i] + ) + ref_sampling_result = self.bbox_sampler.sample( + ref_assign_result, + ref_rpn_results, + ref_batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in ref_feats], + ) + ref_sampling_results.append(ref_sampling_result) + + key_bboxes = [res.pos_bboxes for res in key_sampling_results] + key_roi_feats = self.extract_roi_feats(key_feats, key_bboxes) + ref_bboxes = [res.bboxes for res in ref_sampling_results] + ref_roi_feats = self.extract_roi_feats(ref_feats, ref_bboxes) + + loss_track = self.embed_head.loss( + key_roi_feats, + ref_roi_feats, + key_sampling_results, + ref_sampling_results, + gt_match_indices_list, + ) + + return loss_track + + def predict(self, feats: List[Tensor], rescaled_bboxes: List[Tensor]) -> Tensor: + """Perform forward propagation of the tracking head and predict + tracking results on the features of the upstream network. + + Args: + feats (list[Tensor]): Multi level feature maps of `img`. + rescaled_bboxes (list[Tensor]): list of rescaled bboxes in sampling + result. + + Returns: + Tensor: The extracted track features. 
+ """ + bbox_feats = self.extract_roi_feats(feats, rescaled_bboxes) + track_feats = self.embed_head.predict(bbox_feats) + return track_feats diff --git a/masa/models/sam/__init__.py b/masa/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f92c36cd2b76f4b25dfbd8d3e1b7602a60aff135 --- /dev/null +++ b/masa/models/sam/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .predictor import SamPredictor +from .prompt_encoder import PromptEncoder +from .sam import Sam +from .transformer import TwoWayTransformer +from .automatic_mask_generator import SamAutomaticMaskGenerator +from .build_sam import sam_model_registry + +__all__ = [ + "Sam", + "ImageEncoderViT", + "MaskDecoder", + "PromptEncoder", + "TwoWayTransformer", + "SamAutomaticMaskGenerator", + "SamPredictor", + "sam_model_registry", +] diff --git a/masa/models/sam/__pycache__/__init__.cpython-311.pyc b/masa/models/sam/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bdba3d48e86533c6279dd604fa6614360841c46 Binary files /dev/null and b/masa/models/sam/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/amg.cpython-311.pyc b/masa/models/sam/__pycache__/amg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e436378b1b34abbbae95b55c21977009f6f8c670 Binary files /dev/null and b/masa/models/sam/__pycache__/amg.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/automatic_mask_generator.cpython-311.pyc b/masa/models/sam/__pycache__/automatic_mask_generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0dd5063e913fba1586e54429aefcee104b7728 Binary files /dev/null and b/masa/models/sam/__pycache__/automatic_mask_generator.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/build_sam.cpython-311.pyc b/masa/models/sam/__pycache__/build_sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6fea8340cd0be92d2824c50e9ab45dc10da9f9e Binary files /dev/null and b/masa/models/sam/__pycache__/build_sam.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/common.cpython-311.pyc b/masa/models/sam/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92a7e81cdbc98ab7a35d81fa3b3089753d2846e Binary files /dev/null and b/masa/models/sam/__pycache__/common.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/image_encoder.cpython-311.pyc b/masa/models/sam/__pycache__/image_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33c34480fd33ed477729f48e8956a203573fc75c Binary files /dev/null and b/masa/models/sam/__pycache__/image_encoder.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/mask_decoder.cpython-311.pyc b/masa/models/sam/__pycache__/mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f584775c5e6f99afed152ce4c7390085a5784857 Binary files /dev/null and b/masa/models/sam/__pycache__/mask_decoder.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/predictor.cpython-311.pyc b/masa/models/sam/__pycache__/predictor.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5fccf2fc98f54eec6dec0d1d8c21deec74cc2d4f Binary files /dev/null and b/masa/models/sam/__pycache__/predictor.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/prompt_encoder.cpython-311.pyc b/masa/models/sam/__pycache__/prompt_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ddfb91854da3cd51263bfa3c7e23051a8df3793 Binary files /dev/null and b/masa/models/sam/__pycache__/prompt_encoder.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/sam.cpython-311.pyc b/masa/models/sam/__pycache__/sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a508f02b762b223de0688ebdae51032805a232a Binary files /dev/null and b/masa/models/sam/__pycache__/sam.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/transformer.cpython-311.pyc b/masa/models/sam/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baeb904107a1d614328025c2b3c0023a9ce04061 Binary files /dev/null and b/masa/models/sam/__pycache__/transformer.cpython-311.pyc differ diff --git a/masa/models/sam/__pycache__/transforms.cpython-311.pyc b/masa/models/sam/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..542da43655e05b784f18f2f2887067840ae96eca Binary files /dev/null and b/masa/models/sam/__pycache__/transforms.cpython-311.pyc differ diff --git a/masa/models/sam/amg.py b/masa/models/sam/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..86747edf2c01d763a48d49156933aa1499c2527c --- /dev/null +++ b/masa/models/sam/amg.py @@ -0,0 +1,340 @@ +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." 
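MaskData stores every per-mask attribute as a parallel, equally sized column, so a single boolean filter trims all of them consistently. A small usage sketch, assuming the class is importable from this file's module path (masa.models.sam.amg):

import torch
from masa.models.sam.amg import MaskData

data = MaskData(
    iou_preds=torch.tensor([0.9, 0.4, 0.8]),
    boxes=torch.zeros(3, 4),
)
keep = data["iou_preds"] > 0.5      # boolean mask over the batch dimension
data.filter(keep)                   # rows 0 and 2 survive in every column
assert data["boxes"].shape == (2, 4)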
+ self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. 
+ """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecesary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer ** i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple, n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. 
+ """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
+ """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/masa/models/sam/automatic_mask_generator.py b/masa/models/sam/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2d74fc52c82ad6e29a8c91231800a52f4cbfd4 --- /dev/null +++ b/masa/models/sam/automatic_mask_generator.py @@ -0,0 +1,431 @@ +""" +Modified from the original SAM +- No longer transfer the predicted mask to the CPU (since we need it on GPU later) +- No longer compute RLE +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from .amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, + box_xyxy_to_xywh, build_all_layer_point_grids, + calculate_stability_score, coco_encode_rle, + generate_crop_boxes, is_box_near_crop_edge, + mask_to_rle_pytorch, remove_small_regions, rle_to_mask, + uncrop_boxes_xyxy, uncrop_masks, uncrop_points) +from .predictor import SamPredictor +from .sam import Sam + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. 
+ points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crops_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crops_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, crop_n_layers, crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
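points_per_side defines a regular prompt grid in normalised [0, 1] coordinates, and build_all_layer_point_grids rescales its density for each crop layer. A quick sketch of the grid arithmetic, assuming the helper is imported from this repo's amg module:

import numpy as np
from masa.models.sam.amg import build_point_grid

grid = build_point_grid(4)                    # 4 points per side -> 16 prompts
assert grid.shape == (16, 2)
assert np.allclose(grid[0], [0.125, 0.125])   # first point sits at offset 1 / (2 * 4)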
+ if output_mode == "coco_rle": + from pycocotools import \ + mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate( + self, + image_features, + orig_size, + transformed_size, + positive_points: Optional[np.ndarray] = None, + negative_points: Optional[np.ndarray] = None, + ) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks( + image_features, + orig_size, + transformed_size, + positive_points, + negative_points, + ) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # # Encode masks + # if self.output_mode == "coco_rle": + # mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + # elif self.output_mode == "binary_mask": + # mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + # else: + # mask_data["segmentations"] = mask_data["rles"] + + # # Write mask records + # curr_anns = [] + # for idx in range(len(mask_data["segmentations"])): + # ann = { + # "segmentation": mask_data["segmentations"][idx], + # "area": area_from_rle(mask_data["rles"][idx]), + # "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + # "predicted_iou": mask_data["iou_preds"][idx].item(), + # "point_coords": [mask_data["points"][idx].tolist()], + # "stability_score": mask_data["stability_score"][idx].item(), + # "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + # } + # curr_anns.append(ann) + + # return curr_anns + return mask_data + + def _generate_masks( + self, + image_features, + orig_size, + transformed_size, + positive_points: Optional[np.ndarray] = None, + negative_points: Optional[np.ndarray] = None, + ) -> MaskData: + # orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop( + image_features, + crop_box, + layer_idx, + orig_size, + transformed_size, + positive_points, + negative_points, + ) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + # data.to_numpy() + return data + + def _process_crop( + self, + image_features, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + transformed_size: Tuple[int, ...], + positive_points: Optional[np.ndarray] = None, + negative_points: Optional[np.ndarray] = None, + ) -> MaskData: + # Crop the image and calculate embeddings + + self.predictor.set_image_features(image_features, orig_size, transformed_size) + cropped_im_size = orig_size + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + if positive_points is not None: + points_for_image = positive_points * points_scale + else: + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + if negative_points is not None: + negative_points = negative_points * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, negative_points, cropped_im_size, crop_box, orig_size + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + # data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + negative_points: Optional[np.ndarray], + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + # negative_points = None + if negative_points is not None: + # with negative points + negative_points = np.repeat( + negative_points[None, :, :], points.shape[0], axis=0 + ) + points = np.concatenate([points[:, None, :], negative_points], axis=1) + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor( + transformed_points, device=self.predictor.device + ) + in_labels = torch.zeros( + (in_points.shape[0], in_points.shape[1]), + dtype=torch.int, + device=in_points.device, + ) + in_labels[:, 0] = 1 + else: + # positive points only + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor( + transformed_points, device=self.predictor.device + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + in_points = in_points[:, None, :] + in_labels = in_labels[:, None] + + masks, iou_preds, _ = self.predictor.predict_torch( + in_points, in_labels, multimask_output=True, return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], + self.predictor.model.mask_threshold, + self.stability_score_offset, + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + # data["rles"] = mask_to_rle_pytorch(data["masks"]) + # del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/masa/models/sam/build_sam.py b/masa/models/sam/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..388cd062f677c16749ae735867eed839d718b004 --- /dev/null +++ b/masa/models/sam/build_sam.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch + +from ..sam import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, + TwoWayTransformer) + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + 
transformer=TwoWayTransformer( + depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/masa/models/sam/common.py b/masa/models/sam/common.py new file mode 100644 index 0000000000000000000000000000000000000000..554da79bdf6414465ff60be4ed2cfbf612021bbe --- /dev/null +++ b/masa/models/sam/common.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Type + +import torch +import torch.nn as nn + + +class MLPBlock(nn.Module): + def __init__( + self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/masa/models/sam/image_encoder.py b/masa/models/sam/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..707ea16cf0b904bc7dd09fa8ad92ad23aefb50e6 --- /dev/null +++ b/masa/models/sam/image_encoder.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.registry import MODELS + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the SAM ViTDet backbone. +@MODELS.register_module() +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = partial(torch.nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + out_indices: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. 
+ in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + self.out_indices = out_indices + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False,), + LayerNorm2d(out_chans), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False,), + LayerNorm2d(out_chans), + ) + + def forward_base(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + return x.permute(0, 3, 1, 2) + + def forward_base_multi_level(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + features_list = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in self.out_indices: + features_list.append(x.permute(0, 3, 1, 2)) + + return features_list + + def forward_neck(self, x: torch.Tensor) -> torch.Tensor: + x = self.neck(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_base(x) + x = self.forward_neck(x) + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. 
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. 
+ + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/masa/models/sam/mask_decoder.py b/masa/models/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a74307c854210b29645a12e427b21c2685c5d1ab --- /dev/null +++ b/masa/models/sam/mask_decoder.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple, Type + +import torch +from mmdet.registry import MODELS +from mmengine.model import BaseModule +from torch import nn +from torch.nn import functional as F + +from .common import LayerNorm2d +from .transformer import TwoWayTransformer + + +@MODELS.register_module() +class MaskDecoder(BaseModule): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module = TwoWayTransformer( + depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8, + ), + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + return_features: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + if return_features: + ( + masks, + iou_pred, + hyper_in, + upscaled_embedding, + ) = self.predict_masks_with_feature( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + else: + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + if return_features: + return masks, iou_pred, hyper_in, upscaled_embedding + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + def predict_masks_with_feature( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred, hyper_in, upscaled_embedding + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + 
[output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/masa/models/sam/predictor.py b/masa/models/sam/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d2cd90002a59c53b3e8eec790b8626a03c8fd818 --- /dev/null +++ b/masa/models/sam/predictor.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import numpy as np +import torch + +from .sam import Sam +from .transforms import ResizeLongestSide + + +class SamPredictor: + def __init__(self, sam_model: Sam,) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image(self, image: np.ndarray, image_format: str = "RGB",) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ + None, :, :, : + ] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 
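For orientation, the intended single-image workflow of this predictor (together with `predict` defined further below and the builders in `build_sam.py`) looks roughly as follows. This is a hedged usage sketch: the checkpoint path and the dummy image are placeholders, and the import paths assume the modules are importable exactly as laid out in this diff.

```python
import numpy as np
from masa.models.sam.build_sam import sam_model_registry   # assumed import path
from masa.models.sam.predictor import SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")   # placeholder path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)        # HWC uint8 RGB image (dummy)
predictor.set_image(image)                             # runs the ViT image encoder once

point = np.array([[320.0, 240.0]])                     # one click, (X, Y) in pixels
label = np.array([1])                                  # 1 = foreground, 0 = background
masks, scores, low_res = predictor.predict(
    point_coords=point, point_labels=label, multimask_output=True
)
print(masks.shape, scores.shape)                       # (3, 480, 640) (3,)
```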
+ self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + @torch.no_grad() + def set_image_features( + self, + image_features: torch.Tensor, + original_image_size: Tuple[int, ...], + transformed_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + + self.reset_image() + self.original_size = original_image_size + self.input_size = tuple(transformed_image_size) + self.features = image_features + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." 
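`set_image_features` above is the fork-specific entry point that lets MASA hand precomputed encoder features to the predictor and then decode many prompts with `predict_torch` without re-running the image encoder. A self-contained shape-check sketch with random weights and a placeholder feature tensor (illustrative only, not the actual pipeline wiring):

```python
import torch
from masa.models.sam.build_sam import build_sam_vit_b   # assumed import path
from masa.models.sam.predictor import SamPredictor

predictor = SamPredictor(build_sam_vit_b())       # random weights are fine for a shape check

orig_size = (480, 640)                            # original (H, W)
transformed_size = (768, 1024)                    # (H, W) after ResizeLongestSide to a 1024 long side
features = torch.randn(1, 256, 64, 64)            # placeholder 1 x C x 64 x 64 image embedding
predictor.set_image_features(features, orig_size, transformed_size)

# Two single-point prompts, already expressed in the 1024-pixel input frame:
coords = torch.tensor([[[512.0, 384.0]], [[100.0, 200.0]]])   # B=2, N=1, (X, Y)
labels = torch.ones((2, 1), dtype=torch.int)                  # all foreground
masks, iou_preds, low_res = predictor.predict_torch(coords, labels, multimask_output=True)
print(masks.shape)                                 # torch.Size([2, 3, 480, 640])
```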
+ point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + labels_torch = torch.as_tensor( + point_labels, dtype=torch.int, device=self.device + ) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=self.device + ) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." 
+ ) + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, boxes=boxes, masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self.features is not None + ), "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/masa/models/sam/prompt_encoder.py b/masa/models/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6e22f900c9f1c3b41012dacd695993a2320e6b --- /dev/null +++ b/masa/models/sam/prompt_encoder.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Tuple, Type + +import numpy as np +import torch +from mmdet.registry import MODELS +from mmengine.model import BaseModule +from torch import nn + +from .common import LayerNorm2d + + +@MODELS.register_module() +class PromptEncoder(BaseModule): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, points: torch.Tensor, labels: torch.Tensor, pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/masa/models/sam/sam.py b/masa/models/sam/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5d0a73e7daafad74ee2d3ee6272601d55d7c00 --- /dev/null +++ b/masa/models/sam/sam.py @@ -0,0 +1,175 @@ +from typing import Any, Dict, List, Tuple + +import torch +from mmdet.registry import MODELS +from torch import nn +from torch.nn import functional as F + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +@MODELS.register_module() +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False + ) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, batched_input: List[Dict[str, Any]], multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). 
+ 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack( + [self.preprocess(x["image"]) for x in batched_input], dim=0 + ) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate( + masks, original_size, mode="bilinear", align_corners=False + ) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/masa/models/sam/transformer.py b/masa/models/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4bd8a456116921f910e55ae077af5c9a86dcac --- /dev/null +++ b/masa/models/sam/transformer.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple, Type + +import torch +from torch import Tensor, nn + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. 
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. 
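A hypothetical shape check for the two-way transformer defined above, assuming the `masa.models.sam.transformer` module (and its `.common` dependency) is importable as laid out in this diff. It shows the flattening of the B x C x H x W image embedding into image tokens and the shapes returned by `forward`.

```python
# Sketch of TwoWayTransformer usage; the import path is an assumption based on this diff's layout.
import torch
from masa.models.sam.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W image features
image_pe = torch.randn(1, 256, 64, 64)          # positional encoding, same shape as the embedding
point_embedding = torch.randn(1, 5, 256)        # B x N_points x C query tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)   # torch.Size([1, 5, 256])     processed point embeddings
print(keys.shape)      # torch.Size([1, 4096, 256])  processed image tokens (64*64)
```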
+ """ + + def __init__( + self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/masa/models/sam/transforms.py b/masa/models/sam/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..53682ea95b8be4cdac70ee65ce62d6a33977e2dd --- /dev/null +++ b/masa/models/sam/transforms.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import (resize, # type: ignore + to_pil_image) + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords( + self, coords: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes( + self, boxes: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array shape Bx4. 
Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape( + oldh: int, oldw: int, long_side_length: int + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/masa/models/tracker/__init__.py b/masa/models/tracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609656c09ca0b7b8fad95d0ded2b7a6238b502b9 --- /dev/null +++ b/masa/models/tracker/__init__.py @@ -0,0 +1,2 @@ +from .masa_bdd_tracker import MasaBDDTracker +from .masa_tao_tracker import MasaTaoTracker diff --git a/masa/models/tracker/__pycache__/__init__.cpython-311.pyc b/masa/models/tracker/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73a590958417bf7cb273dbfbe778137dcc72e000 Binary files /dev/null and b/masa/models/tracker/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/models/tracker/__pycache__/masa_bdd_tracker.cpython-311.pyc b/masa/models/tracker/__pycache__/masa_bdd_tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9383db037d3ba00f0aa6b1b69f9297a7a238729e Binary files /dev/null and b/masa/models/tracker/__pycache__/masa_bdd_tracker.cpython-311.pyc differ diff --git a/masa/models/tracker/__pycache__/masa_tao_tracker.cpython-311.pyc b/masa/models/tracker/__pycache__/masa_tao_tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbdae717921e92991818c54ddf083724ee3a8927 Binary files /dev/null and b/masa/models/tracker/__pycache__/masa_tao_tracker.cpython-311.pyc differ diff --git a/masa/models/tracker/masa_bdd_tracker.py b/masa/models/tracker/masa_bdd_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..8c18e19780a45e5b54edf25811a05396d30601ea --- /dev/null +++ 
b/masa/models/tracker/masa_bdd_tracker.py @@ -0,0 +1,351 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmdet.models.trackers.base_tracker import BaseTracker +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample +from mmdet.structures.bbox import bbox_overlaps +from mmengine.structures import InstanceData +from torch import Tensor + + +@MODELS.register_module() +class MasaBDDTracker(BaseTracker): + """Tracker for MASA on BDD benchmark. + + Args: + init_score_thr (float): The cls_score threshold to + initialize a new tracklet. Defaults to 0.8. + obj_score_thr (float): The cls_score threshold to + update a tracked tracklet. Defaults to 0.5. + match_score_thr (float): The match threshold. Defaults to 0.5. + memo_tracklet_frames (int): The most frames in a tracklet memory. + Defaults to 10. + memo_backdrop_frames (int): The most frames in the backdrops. + Defaults to 1. + memo_momentum (float): The momentum value for embeds updating. + Defaults to 0.8. + nms_conf_thr (float): The NMS threshold for confidence. + Defaults to 0.5. + nms_backdrop_iou_thr (float): The NMS threshold for backdrop IoU. + Defaults to 0.3. + nms_class_iou_thr (float): The NMS threshold for class IoU. + Defaults to 0.7. + with_cats (bool): Whether to track with the same category. + Defaults to False. + match_metric (str): The match metric. Can be 'bisoftmax', 'softmax', or 'cosine'. Defaults to 'bisoftmax'. + """ + + def __init__( + self, + init_score_thr: float = 0.8, + obj_score_thr: float = 0.5, + match_score_thr: float = 0.5, + memo_tracklet_frames: int = 10, + memo_backdrop_frames: int = 1, + memo_momentum: float = 0.8, + nms_conf_thr: float = 0.5, + nms_backdrop_iou_thr: float = 0.3, + nms_class_iou_thr: float = 0.7, + with_cats: bool = False, + match_metric: str = "bisoftmax", + **kwargs + ): + super().__init__(**kwargs) + assert 0 <= memo_momentum <= 1.0 + assert memo_tracklet_frames >= 0 + assert memo_backdrop_frames >= 0 + + self.init_score_thr = init_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.memo_tracklet_frames = memo_tracklet_frames + self.memo_backdrop_frames = memo_backdrop_frames + self.memo_momentum = memo_momentum + self.nms_conf_thr = nms_conf_thr + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + self.nms_class_iou_thr = nms_class_iou_thr + self.with_cats = with_cats + assert match_metric in ["bisoftmax", "softmax", "cosine"] + self.match_metric = match_metric + + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def reset(self): + """Reset the buffer of the tracker.""" + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def update( + self, + ids: Tensor, + bboxes: Tensor, + embeds: Tensor, + labels: Tensor, + scores: Tensor, + frame_id: int, + ) -> None: + """Tracking forward function. + + Args: + ids (Tensor): of shape(N, ). + bboxes (Tensor): of shape (N, 5). + embeds (Tensor): of shape (N, 256). + labels (Tensor): of shape (N, ). + scores (Tensor): of shape (N, ). + frame_id (int): The id of current frame, 0-index. 
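A minimal sketch of the exponential-moving-average embedding update controlled by `memo_momentum` (0.8 by default, per the docstring above): when a detection is matched to an existing track, the stored appearance embedding is pulled toward the new embedding.

```python
# Sketch of the memo_momentum EMA update applied to a matched track's embedding.
import torch

memo_momentum = 0.8
track_embed = torch.tensor([1.0, 0.0])   # embedding currently stored for the track
new_embed = torch.tensor([0.0, 1.0])     # embedding of the newly matched detection

track_embed = (1 - memo_momentum) * track_embed + memo_momentum * new_embed
print(track_embed)                        # tensor([0.2000, 0.8000])
```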
+ """ + tracklet_inds = ids > -1 + + for id, bbox, embed, label, score in zip( + ids[tracklet_inds], + bboxes[tracklet_inds], + embeds[tracklet_inds], + labels[tracklet_inds], + scores[tracklet_inds], + ): + id = int(id) + # update the tracked ones and initialize new tracks + if id in self.tracks.keys(): + velocity = (bbox - self.tracks[id]["bbox"]) / ( + frame_id - self.tracks[id]["last_frame"] + ) + self.tracks[id]["bbox"] = bbox + self.tracks[id]["embed"] = (1 - self.memo_momentum) * self.tracks[id][ + "embed" + ] + self.memo_momentum * embed + self.tracks[id]["last_frame"] = frame_id + self.tracks[id]["label"] = label + self.tracks[id]["score"] = score + self.tracks[id]["velocity"] = ( + self.tracks[id]["velocity"] * self.tracks[id]["acc_frame"] + + velocity + ) / (self.tracks[id]["acc_frame"] + 1) + self.tracks[id]["acc_frame"] += 1 + else: + self.tracks[id] = dict( + bbox=bbox, + embed=embed, + label=label, + score=score, + last_frame=frame_id, + velocity=torch.zeros_like(bbox), + acc_frame=0, + ) + # backdrop update according to IoU + backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1) + ious = bbox_overlaps(bboxes[backdrop_inds], bboxes) + for i, ind in enumerate(backdrop_inds): + if (ious[i, :ind] > self.nms_backdrop_iou_thr).any(): + backdrop_inds[i] = -1 + backdrop_inds = backdrop_inds[backdrop_inds > -1] + # old backdrops would be removed at first + self.backdrops.insert( + 0, + dict( + bboxes=bboxes[backdrop_inds], + embeds=embeds[backdrop_inds], + labels=labels[backdrop_inds], + ), + ) + + # pop memo + invalid_ids = [] + for k, v in self.tracks.items(): + if frame_id - v["last_frame"] >= self.memo_tracklet_frames: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracks.pop(invalid_id) + + if len(self.backdrops) > self.memo_backdrop_frames: + self.backdrops.pop() + + @property + def memo(self) -> Tuple[Tensor, ...]: + """Get tracks memory.""" + memo_embeds = [] + memo_ids = [] + memo_bboxes = [] + memo_labels = [] + # velocity of tracks + memo_vs = [] + # get tracks + for k, v in self.tracks.items(): + memo_bboxes.append(v["bbox"][None, :]) + memo_embeds.append(v["embed"][None, :]) + memo_ids.append(k) + memo_labels.append(v["label"].view(1, 1)) + memo_vs.append(v["velocity"][None, :]) + memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1) + # get backdrops + for backdrop in self.backdrops: + backdrop_ids = torch.full( + (1, backdrop["embeds"].size(0)), -1, dtype=torch.long + ) + backdrop_vs = torch.zeros_like(backdrop["bboxes"]) + memo_bboxes.append(backdrop["bboxes"]) + memo_embeds.append(backdrop["embeds"]) + memo_ids = torch.cat([memo_ids, backdrop_ids], dim=1) + memo_labels.append(backdrop["labels"][:, None]) + memo_vs.append(backdrop_vs) + + memo_bboxes = torch.cat(memo_bboxes, dim=0) + memo_embeds = torch.cat(memo_embeds, dim=0) + memo_labels = torch.cat(memo_labels, dim=0).squeeze(1) + memo_vs = torch.cat(memo_vs, dim=0) + return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(0), memo_vs + + def track( + self, + model: torch.nn.Module, + img: torch.Tensor, + feats: List[torch.Tensor], + data_sample: TrackDataSample, + rescale=True, + with_segm=False, + **kwargs + ) -> InstanceData: + """Tracking forward function. + + Args: + model (nn.Module): MOT model. + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + The T denotes the number of key images and usually is 1 in + QDTrack method. + feats (list[Tensor]): Multi level feature maps of `img`. 
+ data_sample (:obj:`TrackDataSample`): The data sample. + It includes information such as `pred_instances`. + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the original scale of the image. Defaults to + True. + + Returns: + :obj:`InstanceData`: Tracking results of the input images. + Each InstanceData usually contains ``bboxes``, ``labels``, + ``scores`` and ``instances_id``. + """ + metainfo = data_sample.metainfo + bboxes = data_sample.pred_instances.bboxes + labels = data_sample.pred_instances.labels + scores = data_sample.pred_instances.scores + + frame_id = metainfo.get("frame_id", -1) + # create pred_track_instances + pred_track_instances = InstanceData() + + # return zero bboxes if there is no track targets + if bboxes.shape[0] == 0: + ids = torch.zeros_like(labels) + pred_track_instances = data_sample.pred_instances.clone() + pred_track_instances.instances_id = ids + return pred_track_instances + + # get track feats + rescaled_bboxes = bboxes.clone() + if rescale: + scale_factor = rescaled_bboxes.new_tensor(metainfo["scale_factor"]).repeat( + (1, 2) + ) + rescaled_bboxes = rescaled_bboxes * scale_factor + track_feats = model.track_head.predict(feats, [rescaled_bboxes]) + # sort according to the object_score + _, inds = scores.sort(descending=True) + bboxes = bboxes[inds] + scores = scores[inds] + labels = labels[inds] + embeds = track_feats[inds, :] + if with_segm: + mask_inds = torch.arange(bboxes.size(0)).to(embeds.device) + mask_inds = mask_inds[inds] + else: + mask_inds = [] + + # duplicate removal for potential backdrops and cross classes + valids = bboxes.new_ones((bboxes.size(0))) + ious = bbox_overlaps(bboxes, bboxes) + for i in range(1, bboxes.size(0)): + thr = ( + self.nms_backdrop_iou_thr + if scores[i] < self.obj_score_thr + else self.nms_class_iou_thr + ) + if (ious[i, :i] > thr).any(): + valids[i] = 0 + valids = valids == 1 + bboxes = bboxes[valids] + scores = scores[valids] + labels = labels[valids] + embeds = embeds[valids, :] + if with_segm: + mask_inds = mask_inds[valids] + + # init ids container + ids = torch.full((bboxes.size(0),), -1, dtype=torch.long) + + # match if buffer is not empty + if bboxes.size(0) > 0 and not self.empty: + (memo_bboxes, memo_labels, memo_embeds, memo_ids, memo_vs) = self.memo + + if self.match_metric == "bisoftmax": + feats = torch.mm(embeds, memo_embeds.t()) + d2t_scores = feats.softmax(dim=1) + t2d_scores = feats.softmax(dim=0) + match_scores = (d2t_scores + t2d_scores) / 2 + elif self.match_metric == "softmax": + feats = torch.mm(embeds, memo_embeds.t()) + match_scores = feats.softmax(dim=1) + elif self.match_metric == "cosine": + match_scores = torch.mm( + F.normalize(embeds, p=2, dim=1), + F.normalize(memo_embeds, p=2, dim=1).t(), + ) + else: + raise NotImplementedError + # track with the same category + if self.with_cats: + cat_same = labels.view(-1, 1) == memo_labels.view(1, -1) + match_scores *= cat_same.float().to(match_scores.device) + # track according to match_scores + for i in range(bboxes.size(0)): + conf, memo_ind = torch.max(match_scores[i, :], dim=0) + id = memo_ids[memo_ind] + if conf > self.match_score_thr: + if id > -1: + # keep bboxes with high object score + # and remove background bboxes + if scores[i] > self.obj_score_thr: + ids[i] = id + match_scores[:i, memo_ind] = 0 + match_scores[i + 1 :, memo_ind] = 0 + else: + if conf > self.nms_conf_thr: + ids[i] = -2 + # initialize new tracks + new_inds = (ids == -1) & (scores > self.init_score_thr).cpu() + num_news = new_inds.sum() + 
ids[new_inds] = torch.arange( + self.num_tracks, self.num_tracks + num_news, dtype=torch.long + ) + self.num_tracks += num_news + + self.update(ids, bboxes, embeds, labels, scores, frame_id) + tracklet_inds = ids > -1 + # update pred_track_instances + pred_track_instances.bboxes = bboxes[tracklet_inds] + pred_track_instances.labels = labels[tracklet_inds] + pred_track_instances.scores = scores[tracklet_inds] + pred_track_instances.instances_id = ids[tracklet_inds] + if with_segm: + pred_track_instances.mask_inds = mask_inds[tracklet_inds] + + return pred_track_instances diff --git a/masa/models/tracker/masa_tao_tracker.py b/masa/models/tracker/masa_tao_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9c00c0105b356a43ab748fb9d4dde9fec3aeff --- /dev/null +++ b/masa/models/tracker/masa_tao_tracker.py @@ -0,0 +1,383 @@ +""" +Author: Siyuan Li +Licensed: Apache-2.0 License +""" + +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmdet.models.trackers.base_tracker import BaseTracker +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample +from mmdet.structures.bbox import bbox_overlaps +from mmengine.structures import InstanceData +from torch import Tensor + + +@MODELS.register_module() +class MasaTaoTracker(BaseTracker): + """Tracker for MASA on TAO benchmark. + + Args: + init_score_thr (float): The cls_score threshold to + initialize a new tracklet. Defaults to 0.8. + obj_score_thr (float): The cls_score threshold to + update a tracked tracklet. Defaults to 0.5. + match_score_thr (float): The match threshold. Defaults to 0.5. + memo_tracklet_frames (int): The most frames in a tracklet memory. + Defaults to 10. + memo_momentum (float): The momentum value for embeds updating. + Defaults to 0.8. + distractor_score_thr (float): The score threshold to consider an object as a distractor. + Defaults to 0.5. + distractor_nms_thr (float): The NMS threshold for filtering out distractors. + Defaults to 0.3. + with_cats (bool): Whether to track with the same category. + Defaults to True. + match_metric (str): The match metric. Can be 'bisoftmax', 'softmax', or 'cosine'. Defaults to 'bisoftmax'. + max_distance (float): Maximum distance for considering matches. Defaults to -1. + fps (int): Frames per second of the input video. Used for calculating growth factor. Defaults to 1. 
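A minimal sketch of the bi-directional softmax ("bisoftmax") match score used in the tracking step above: the raw detection-to-memory similarity matrix is softmax-normalized over tracks and over detections, and the two views are averaged before greedy assignment.

```python
# Sketch of the bisoftmax match metric between current detections and the track memory.
import torch
import torch.nn.functional as F

embeds = F.normalize(torch.randn(3, 256), dim=1)         # 3 current detection embeddings
memo_embeds = F.normalize(torch.randn(5, 256), dim=1)     # 5 embeddings in the track memory

feats = torch.mm(embeds, memo_embeds.t())   # raw similarity, detections x tracks
d2t_scores = feats.softmax(dim=1)           # each detection distributes mass over tracks
t2d_scores = feats.softmax(dim=0)           # each track distributes mass over detections
match_scores = (d2t_scores + t2d_scores) / 2
print(match_scores.shape)                    # torch.Size([3, 5])
```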
+ """ + + def __init__( + self, + init_score_thr: float = 0.8, + obj_score_thr: float = 0.5, + match_score_thr: float = 0.5, + memo_tracklet_frames: int = 10, + memo_momentum: float = 0.8, + distractor_score_thr: float = 0.5, + distractor_nms_thr=0.3, + with_cats: bool = True, + max_distance: float = -1, + fps=1, + **kwargs + ): + super().__init__(**kwargs) + assert 0 <= memo_momentum <= 1.0 + assert memo_tracklet_frames >= 0 + self.init_score_thr = init_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.memo_tracklet_frames = memo_tracklet_frames + self.memo_momentum = memo_momentum + self.distractor_score_thr = distractor_score_thr + self.distractor_nms_thr = distractor_nms_thr + self.with_cats = with_cats + + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + self.max_distance = max_distance # Maximum distance for considering matches + self.fps = fps + self.growth_factor = self.fps / 6 # Growth factor for the distance mask + self.distance_smoothing_factor = 100 / self.fps + + def reset(self): + """Reset the buffer of the tracker.""" + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def update( + self, + ids: Tensor, + bboxes: Tensor, + embeds: Tensor, + labels: Tensor, + scores: Tensor, + frame_id: int, + ) -> None: + """Tracking forward function. + + Args: + ids (Tensor): of shape(N, ). + bboxes (Tensor): of shape (N, 5). + embeds (Tensor): of shape (N, 256). + labels (Tensor): of shape (N, ). + scores (Tensor): of shape (N, ). + frame_id (int): The id of current frame, 0-index. + """ + tracklet_inds = ids > -1 + + for id, bbox, embed, label, score in zip( + ids[tracklet_inds], + bboxes[tracklet_inds], + embeds[tracklet_inds], + labels[tracklet_inds], + scores[tracklet_inds], + ): + id = int(id) + # update the tracked ones and initialize new tracks + if id in self.tracks.keys(): + self.tracks[id]["bbox"] = bbox + self.tracks[id]["embed"] = (1 - self.memo_momentum) * self.tracks[id][ + "embed" + ] + self.memo_momentum * embed + self.tracks[id]["last_frame"] = frame_id + self.tracks[id]["label"] = label + self.tracks[id]["score"] = score + else: + self.tracks[id] = dict( + bbox=bbox, + embed=embed, + label=label, + score=score, + last_frame=frame_id, + ) + + # pop memo + invalid_ids = [] + for k, v in self.tracks.items(): + if frame_id - v["last_frame"] >= self.memo_tracklet_frames: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracks.pop(invalid_id) + + @property + def memo(self) -> Tuple[Tensor, ...]: + """Get tracks memory.""" + memo_embeds = [] + memo_ids = [] + memo_bboxes = [] + memo_labels = [] + memo_frame_ids = [] + + # get tracks + for k, v in self.tracks.items(): + memo_bboxes.append(v["bbox"][None, :]) + memo_embeds.append(v["embed"][None, :]) + memo_ids.append(k) + memo_labels.append(v["label"].view(1, 1)) + memo_frame_ids.append(v["last_frame"]) + + memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1) + memo_bboxes = torch.cat(memo_bboxes, dim=0) + memo_embeds = torch.cat(memo_embeds, dim=0) + memo_labels = torch.cat(memo_labels, dim=0).squeeze(1) + memo_frame_ids = torch.tensor(memo_frame_ids, dtype=torch.long).view(1, -1) + + return ( + memo_bboxes, + memo_labels, + memo_embeds, + memo_ids.squeeze(0), + memo_frame_ids.squeeze(0), + ) + + def compute_distance_mask(self, bboxes1, bboxes2, frame_ids1, frame_ids2): + """Compute a mask based on the pairwise center distances and frame IDs with piecewise soft-weighting.""" + centers1 = (bboxes1[:, :2] + bboxes1[:, 2:]) / 
2.0 + centers2 = (bboxes2[:, :2] + bboxes2[:, 2:]) / 2.0 + distances = torch.cdist(centers1, centers2) + + frame_id_diff = torch.abs(frame_ids1[:, None] - frame_ids2[None, :]).to( + distances.device + ) + + # Define a scaling factor for the distance based on frame difference (exponential growth) + scaling_factor = torch.exp(frame_id_diff.float() / self.growth_factor) + + # Apply the scaling factor to max_distance + adaptive_max_distance = self.max_distance * scaling_factor + + # Create a piecewise function for soft gating + soft_distance_mask = torch.where( + distances <= adaptive_max_distance, + torch.ones_like(distances), + torch.exp( + -(distances - adaptive_max_distance) / self.distance_smoothing_factor + ), + ) + + return soft_distance_mask + + def track( + self, + model: torch.nn.Module, + img: torch.Tensor, + feats: List[torch.Tensor], + data_sample: TrackDataSample, + rescale=True, + with_segm=False, + **kwargs + ) -> InstanceData: + """Tracking forward function. + + Args: + model (nn.Module): MOT model. + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + The T denotes the number of key images and usually is 1. + feats (list[Tensor]): Multi level feature maps of `img`. + data_sample (:obj:`TrackDataSample`): The data sample. + It includes information such as `pred_instances`. + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the original scale of the image. Defaults to + True. + + Returns: + :obj:`InstanceData`: Tracking results of the input images. + Each InstanceData usually contains ``bboxes``, ``labels``, + ``scores`` and ``instances_id``. + """ + metainfo = data_sample.metainfo + bboxes = data_sample.pred_instances.bboxes + labels = data_sample.pred_instances.labels + scores = data_sample.pred_instances.scores + + frame_id = metainfo.get("frame_id", -1) + # create pred_track_instances + pred_track_instances = InstanceData() + + # return zero bboxes if there is no track targets + if bboxes.shape[0] == 0: + ids = torch.zeros_like(labels) + pred_track_instances = data_sample.pred_instances.clone() + pred_track_instances.instances_id = ids + pred_track_instances.mask_inds = torch.zeros_like(labels) + return pred_track_instances + + # get track feats + rescaled_bboxes = bboxes.clone() + if rescale: + scale_factor = rescaled_bboxes.new_tensor(metainfo["scale_factor"]).repeat( + (1, 2) + ) + rescaled_bboxes = rescaled_bboxes * scale_factor + track_feats = model.track_head.predict(feats, [rescaled_bboxes]) + # sort according to the object_score + _, inds = scores.sort(descending=True) + bboxes = bboxes[inds] + scores = scores[inds] + labels = labels[inds] + embeds = track_feats[inds, :] + if with_segm: + mask_inds = torch.arange(bboxes.size(0)).to(embeds.device) + mask_inds = mask_inds[inds] + else: + mask_inds = [] + + bboxes, labels, scores, embeds, mask_inds = self.remove_distractor( + bboxes, + labels, + scores, + track_feats=embeds, + mask_inds=mask_inds, + nms="inter", + distractor_score_thr=self.distractor_score_thr, + distractor_nms_thr=self.distractor_nms_thr, + ) + + # init ids container + ids = torch.full((bboxes.size(0),), -1, dtype=torch.long) + + # match if buffer is not empty + if bboxes.size(0) > 0 and not self.empty: + ( + memo_bboxes, + memo_labels, + memo_embeds, + memo_ids, + memo_frame_ids, + ) = self.memo + + feats = torch.mm(embeds, memo_embeds.t()) + d2t_scores = feats.softmax(dim=1) + t2d_scores = feats.softmax(dim=0) + match_scores_bisoftmax = (d2t_scores + 
t2d_scores) / 2 + + match_scores_cosine = torch.mm( + F.normalize(embeds, p=2, dim=1), + F.normalize(memo_embeds, p=2, dim=1).t(), + ) + + match_scores = (match_scores_bisoftmax + match_scores_cosine) / 2 + + if self.max_distance != -1: + + # Compute the mask based on spatial proximity + current_frame_ids = torch.full( + (bboxes.size(0),), frame_id, dtype=torch.long + ) + distance_mask = self.compute_distance_mask( + bboxes, memo_bboxes, current_frame_ids, memo_frame_ids + ) + + # Apply the mask to the match scores + match_scores = match_scores * distance_mask + + # track according to match_scores + for i in range(bboxes.size(0)): + conf, memo_ind = torch.max(match_scores[i, :], dim=0) + id = memo_ids[memo_ind] + if conf > self.match_score_thr: + if id > -1: + # keep bboxes with high object score + # and remove background bboxes + if scores[i] > self.obj_score_thr: + ids[i] = id + match_scores[:i, memo_ind] = 0 + match_scores[i + 1 :, memo_ind] = 0 + + # initialize new tracks + new_inds = (ids == -1) & (scores > self.init_score_thr).cpu() + num_news = new_inds.sum() + ids[new_inds] = torch.arange( + self.num_tracks, self.num_tracks + num_news, dtype=torch.long + ) + self.num_tracks += num_news + + self.update(ids, bboxes, embeds, labels, scores, frame_id) + tracklet_inds = ids > -1 + # update pred_track_instances + pred_track_instances.bboxes = bboxes[tracklet_inds] + pred_track_instances.labels = labels[tracklet_inds] + pred_track_instances.scores = scores[tracklet_inds] + pred_track_instances.instances_id = ids[tracklet_inds] + if with_segm: + pred_track_instances.mask_inds = mask_inds[tracklet_inds] + + return pred_track_instances + + def remove_distractor( + self, + bboxes, + labels, + scores, + track_feats, + mask_inds=[], + distractor_score_thr=0.5, + distractor_nms_thr=0.3, + nms="inter", + ): + # all objects is valid here + valid_inds = labels > -1 + # nms + low_inds = torch.nonzero(scores < distractor_score_thr, as_tuple=False).squeeze( + 1 + ) + if nms == "inter": + ious = bbox_overlaps(bboxes[low_inds, :], bboxes[:, :]) + elif nms == "intra": + cat_same = labels[low_inds].view(-1, 1) == labels.view(1, -1) + ious = bbox_overlaps(bboxes[low_inds, :], bboxes) + ious *= cat_same.to(ious.device) + else: + raise NotImplementedError + + for i, ind in enumerate(low_inds): + if (ious[i, :ind] > distractor_nms_thr).any(): + valid_inds[ind] = False + + bboxes = bboxes[valid_inds] + labels = labels[valid_inds] + scores = scores[valid_inds] + if track_feats is not None: + track_feats = track_feats[valid_inds] + + if len(mask_inds) > 0: + mask_inds = mask_inds[valid_inds] + + return bboxes, labels, scores, track_feats, mask_inds diff --git a/masa/version.py b/masa/version.py new file mode 100644 index 0000000000000000000000000000000000000000..37036c00f69ddb3ef78b937b32af231a36fbda22 --- /dev/null +++ b/masa/version.py @@ -0,0 +1,26 @@ +__version__ = "0.1.0" + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 
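A sketch of the soft distance gating implemented in `compute_distance_mask` above: pairs whose center distance falls within an adaptively grown radius keep full weight, while farther pairs are down-weighted exponentially rather than hard-masked. The values below use the formulas above with an assumed `fps=30` and `max_distance=100` (the tracker's default `max_distance=-1` disables this gating entirely).

```python
# Sketch of the piecewise soft distance mask, assuming fps=30 and max_distance=100.
import torch

max_distance, fps = 100.0, 30
growth_factor = fps / 6            # adaptive radius grows with frame gap
smoothing = 100 / fps              # controls how quickly far matches decay

distances = torch.tensor([[50.0, 150.0, 400.0]])     # center distances: 1 detection x 3 tracks
frame_id_diff = torch.tensor([[1.0, 1.0, 10.0]])     # frames since each track was last seen

adaptive_max = max_distance * torch.exp(frame_id_diff / growth_factor)
soft_mask = torch.where(
    distances <= adaptive_max,
    torch.ones_like(distances),
    torch.exp(-(distances - adaptive_max) / smoothing),
)
print(soft_mask)   # pairs within the adaptive radius keep weight 1.0; others decay toward 0
```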
+ """ + version_info = [] + for x in version_str.split("."): + if x.isdigit(): + version_info.append(int(x)) + elif x.find("rc") != -1: + patch_version = x.split("rc") + version_info.append(int(patch_version[0])) + version_info.append(f"rc{patch_version[1]}") + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ["__version__", "version_info", "parse_version_info"] diff --git a/masa/visualization/__init__.py b/masa/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d348f65492f0e4b01d5403a1b55da2753fb080c9 --- /dev/null +++ b/masa/visualization/__init__.py @@ -0,0 +1,3 @@ +from .visualizer import MasaTrackLocalVisualizer + +__all__ = ["MasaTrackLocalVisualizer"] diff --git a/masa/visualization/__pycache__/__init__.cpython-311.pyc b/masa/visualization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..675da7c1a29efd8a47f939b979adb2b50ea5621b Binary files /dev/null and b/masa/visualization/__pycache__/__init__.cpython-311.pyc differ diff --git a/masa/visualization/__pycache__/visualizer.cpython-311.pyc b/masa/visualization/__pycache__/visualizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..250445731792156736b4ba51d13615d7e6009f3b Binary files /dev/null and b/masa/visualization/__pycache__/visualizer.cpython-311.pyc differ diff --git a/masa/visualization/visualizer.py b/masa/visualization/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..efd00b67a98108446ab1c859aeb3e4e7dbfbe847 --- /dev/null +++ b/masa/visualization/visualizer.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import mmcv +import numpy as np + +try: + import seaborn as sns +except ImportError: + sns = None +from mmdet.registry import VISUALIZERS +from mmdet.structures import DetDataSample +from mmdet.structures.mask import bitmap_to_polygon +from mmdet.visualization.palette import _get_adaptive_scales +from mmengine.dist import master_only +from mmengine.structures import InstanceData, PixelData +from mmengine.visualization import Visualizer + + +def random_color(seed): + """Random a color according to the input seed.""" + if sns is None: + raise RuntimeError( + "motmetrics is not installed,\ + please install it by: pip install seaborn" + ) + np.random.seed(seed) + colors = sns.color_palette("tab20") + color = colors[np.random.choice(range(len(colors)))] + color = tuple([int(255 * c) for c in color]) + return color + + +@VISUALIZERS.register_module() +class MasaTrackLocalVisualizer(Visualizer): + """Tracking Local Visualizer for the MOT, VIS tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + line_width (int, float): The linewidth of lines. + Defaults to 3. + alpha (int, float): The transparency of bboxes or mask. + Defaults to 0.8. 
+ """ + + def __init__( + self, + name: str = "visualizer", + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + line_width: Union[int, float] = 5, + alpha: float = 0.8, + show_polygon: bool = False, + texts=None, + ) -> None: + super().__init__(name, image, vis_backends, save_dir) + self.line_width = line_width + self.alpha = alpha + self.show_polygon = show_polygon + # Set default value. When calling + # `TrackLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + self.dataset_meta = {} + if texts is not None: + if isinstance(texts, str): + if not texts.endswith("."): + original_caption = texts + " . " + original_caption = original_caption.split(" . ") + class_names = list(filter(lambda x: len(x) > 0, original_caption)) + else: + class_names = list(texts) + self.label_names = class_names + else: + self.label_names = None + + def _draw_instances( + self, image: np.ndarray, instances: ["InstanceData"] + ) -> np.ndarray: + """Draw instances of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + self.set_image(image) + classes = self.dataset_meta.get("classes", None) + + # get colors and texts + # for the MOT and VIS tasks + colors = [random_color(_id) for _id in instances.instances_id] + + # draw bboxes and texts + if "bboxes" in instances: + # draw bboxes + bboxes = instances.bboxes.clone() + labels = instances.labels.clone() + + self.draw_bboxes( + bboxes, + edge_colors=colors, + alpha=self.alpha, + line_widths=self.line_width, + ) + # draw texts + positions = bboxes[:, :2] - self.line_width + areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas.cpu().numpy()) + + for i, (pos, label) in enumerate(zip(positions, labels)): + if self.label_names is not None: + label_text = self.label_names[label] + else: + label_text = ( + classes[label] if classes is not None else f"class {label}" + ) + + if "instances_id" in instances: + label_text += f" | {instances.instances_id[i]}" + + if "scores" in instances: + score = round(float(instances.scores[i]) * 100, 1) + label_text += f": {score}" + + self.draw_texts( + label_text, + pos, + colors="black", + font_sizes=int(13 * scales[i]), + bboxes=[ + { + "facecolor": [c / 255 for c in colors[i]], + "alpha": 0.8, + "pad": 0.7, + "edgecolor": "none", + } + ], + ) + + # draw masks + if "masks" in instances: + masks = instances.masks + polygons = [] + for i, mask in enumerate(masks): + contours, _ = bitmap_to_polygon(mask) + polygons.extend(contours) + if self.show_polygon: + self.draw_polygons(polygons, edge_colors="w", alpha=self.alpha) + self.draw_binary_masks(masks, colors=colors, alphas=self.alpha) + + return self.get_image() + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: DetDataSample = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: int = 0, + # TODO: Supported in mmengine's Viusalizer. + out_file: Optional[str] = None, + pred_score_thr: float = 0.3, + vis_score=False, + step: int = 0, + fps=None, + ) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. 
+ - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. t is usually used when the display + is not available. + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + data_sample (OptTrackSampleList): A data + sample that contain annotations and predictions. + Defaults to None. + draw_gt (bool): Whether to draw GT TrackDataSample. + Default to True. + draw_pred (bool): Whether to draw Prediction TrackDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (int): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + pred_score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + step (int): Global step value to record. Defaults to 0. + """ + gt_img_data = None + pred_img_data = None + + if data_sample is not None: + data_sample = data_sample.cpu() + + if draw_gt and data_sample is not None: + assert "gt_instances" in data_sample + gt_img_data = self._draw_instances(image, data_sample.gt_instances) + + if draw_pred and data_sample is not None: + assert "pred_track_instances" in data_sample + pred_instances = data_sample.pred_track_instances + if "scores" in pred_instances: + pred_instances = pred_instances[ + pred_instances.scores > pred_score_thr + ].cpu() + pred_img_data = self._draw_instances(image, pred_instances) + + if fps is not None: + self.draw_texts( + f"FPS: {fps: .1f}", + np.array([10, 10]), + colors="black", + font_sizes=15, + bboxes=[ + {"facecolor": "w", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"} + ], + ) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step) diff --git a/saved_models/masa_models/gdino_masa.pth b/saved_models/masa_models/gdino_masa.pth new file mode 100644 index 0000000000000000000000000000000000000000..7ba02757229ba7ba0a2b3cd2c33f04ff510ae549 --- /dev/null +++ b/saved_models/masa_models/gdino_masa.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15f4380ae4b17ac98c80a9a56bf6e569bac5bc21592906ccd3945a97438424ad +size 1091120775 diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..dea131b43ea8f1222661d20603d40c18ea7f28a1 --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/test.py \ + $CONFIG \ + $CHECKPOINT \ + --launcher pytorch \ + ${@:4} diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..85fb051cb041f1e5ca6495b8ca1b30ad1fbbeef3 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
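A brief note on the channel flip used when the visualizer saves images above: drawing happens in RGB, while `mmcv.imwrite` (OpenCV backend) expects BGR, hence the `drawn_img[..., ::-1]` before writing to disk.

```python
# Sketch of the RGB -> BGR channel flip applied before mmcv.imwrite.
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255                 # pure red in RGB ordering
bgr = rgb[..., ::-1]              # reverse the channel axis
print(bgr[0, 0])                  # [  0   0 255] -> red in BGR ordering
```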
+import argparse +import os +import sys +import os.path as osp +import warnings +from copy import deepcopy + +from mmengine import ConfigDict +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +from mmdet.engine.hooks.utils import trigger_visualization_hook +from mmdet.evaluation import DumpDetResults +from mmdet.registry import RUNNERS +from mmdet.utils import setup_cache_size_limit_of_dynamo + +# Correct the path to point directly to the root of your project where 'masa' is located +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, project_root) + +import masa +import projects.Detic_new.detic + +# TODO: support fuse_conv_bn and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MASA test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--out', + type=str, + help='dump predictions to a pickle file for offline evaluation') + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--tta', action='store_true') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + # Reduce the number of repeated compilations and improve + # testing speed. 
+ setup_cache_size_limit_of_dynamo() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + + if args.tta: + + if 'tta_model' not in cfg: + warnings.warn('Cannot find ``tta_model`` in config, ' + 'we will set it as default.') + cfg.tta_model = dict( + type='DetTTAModel', + tta_cfg=dict( + nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)) + if 'tta_pipeline' not in cfg: + warnings.warn('Cannot find ``tta_pipeline`` in config, ' + 'we will set it as default.') + test_data_cfg = cfg.test_dataloader.dataset + while 'dataset' in test_data_cfg: + test_data_cfg = test_data_cfg['dataset'] + cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline) + flip_tta = dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='RandomFlip', prob=1.), + dict(type='RandomFlip', prob=0.) + ], + [ + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', + 'img_shape', 'scale_factor', 'flip', + 'flip_direction')) + ], + ]) + cfg.tta_pipeline[-1] = flip_tta + cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + + # build the runner from config + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) + + # add `DumpResults` dummy metric + if args.out is not None: + assert args.out.endswith(('.pkl', '.pickle')), \ + 'The dump file must be a pkl file.' + runner.test_evaluator.metrics.append( + DumpDetResults(out_file_path=args.out)) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() \ No newline at end of file
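A small illustration of the `work_dir` fallback in `tools/test.py` above: when neither the CLI nor the config provides a working directory, it is derived from the config file name. The config path below is a hypothetical placeholder, not a file shipped in this diff.

```python
# Sketch of the default work_dir derivation in tools/test.py; the config path is hypothetical.
import os.path as osp

config_path = "configs/some_masa_config.py"
work_dir = osp.join("./work_dirs", osp.splitext(osp.basename(config_path))[0])
print(work_dir)   # ./work_dirs/some_masa_config
```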