KyanChen committed on
Commit
6c06d1a
1 Parent(s): 3299e88

Upload 34 files

Files changed (32)
  1. .gitattributes +0 -4
  2. App_main.py +109 -0
  3. qs_run.sh +32 -0
  4. readme 2.md +117 -0
  5. readme_cn.md +116 -0
  6. requirements.txt +26 -0
  7. tools/__init__.py +0 -0
  8. tools/ins_seg/analysis_tools/browse_dataset_mmdet_mmyolo_pl.py +266 -0
  9. tools/ins_seg/analysis_tools/dataset_analysis.py +518 -0
  10. tools/ins_seg/dataset_converters/cityscapes.py +152 -0
  11. tools/ins_seg/dataset_converters/whu_building_convert.py +143 -0
  12. tools/ins_seg/sam/sam_cls/get_sam_cls_crops.py +120 -0
  13. tools/ins_seg/sam/sam_cls/get_sam_cls_metrics.py +170 -0
  14. tools/ins_seg/sam/sam_cls/segment_anything/__init__.py +15 -0
  15. tools/ins_seg/sam/sam_cls/segment_anything/automatic_mask_generator.py +372 -0
  16. tools/ins_seg/sam/sam_cls/segment_anything/build_sam.py +107 -0
  17. tools/ins_seg/sam/sam_cls/segment_anything/modeling/__init__.py +11 -0
  18. tools/ins_seg/sam/sam_cls/segment_anything/modeling/common.py +43 -0
  19. tools/ins_seg/sam/sam_cls/segment_anything/modeling/image_encoder.py +395 -0
  20. tools/ins_seg/sam/sam_cls/segment_anything/modeling/mask_decoder.py +176 -0
  21. tools/ins_seg/sam/sam_cls/segment_anything/modeling/prompt_encoder.py +214 -0
  22. tools/ins_seg/sam/sam_cls/segment_anything/modeling/sam.py +174 -0
  23. tools/ins_seg/sam/sam_cls/segment_anything/modeling/transformer.py +240 -0
  24. tools/ins_seg/sam/sam_cls/segment_anything/predictor.py +269 -0
  25. tools/ins_seg/sam/sam_cls/segment_anything/utils/__init__.py +5 -0
  26. tools/ins_seg/sam/sam_cls/segment_anything/utils/amg.py +346 -0
  27. tools/ins_seg/sam/sam_cls/segment_anything/utils/onnx.py +144 -0
  28. tools/ins_seg/sam/sam_cls/segment_anything/utils/transforms.py +102 -0
  29. tools/predict.py +46 -0
  30. tools/test.py +43 -0
  31. tools/train.py +48 -0
  32. visualizer/test_img.jpg +0 -0
.gitattributes CHANGED
@@ -33,7 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- data/WHU/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
37
- data/WHU/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
38
- mmpretrain/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
39
- mmpretrain/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
App_main.py ADDED
@@ -0,0 +1,109 @@
1
+ import glob
2
+ import mmcv
3
+ import mmengine
4
+ import numpy as np
5
+ import os
6
+ from mmengine import Config, get
7
+ from mmengine.dataset import Compose
8
+ from mmpl.registry import MODELS, VISUALIZERS
9
+ from mmpl.utils import register_all_modules
10
+ register_all_modules()
11
+ # os.system('nvidia-smi')
12
+ # os.system('ls /usr/local')
13
+ # os.system('pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117')
14
+ # os.system('pip install -U openmim')
15
+ # os.system('mim install mmcv==2.0.0')
16
+ # os.system('mim install mmengine')
17
+
18
+ import gradio as gr
19
+ import torch
20
+
21
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
22
+
23
+
24
+ def construct_sample(img, pipeline):
25
+ img = np.array(img)[:, :, ::-1]
26
+ inputs = {
27
+ 'ori_shape': img.shape[:2],
28
+ 'img': img,
29
+ }
30
+ pipeline = Compose(pipeline)
31
+ sample = pipeline(inputs)
32
+ return sample
33
+
34
+ def build_model(cp, model_cfg):
35
+ model_cpkt = torch.load(cp, map_location='cpu')
36
+ model = MODELS.build(model_cfg)
37
+ model.load_state_dict(model_cpkt, strict=True)
38
+ model.to(device=device)
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ # Function for building extraction
44
+ def inference_func(ori_img, cp):
45
+ checkpoint = f'pretrain/{cp}_anchor.pth'
46
+ cfg = f'configs/huggingface/rsprompter_anchor_{cp}_config.py'
47
+ cfg = Config.fromfile(cfg)
48
+ sample = construct_sample(ori_img, cfg.predict_pipeline)
49
+ sample['inputs'] = [sample['inputs']]
50
+ sample['data_samples'] = [sample['data_samples']]
51
+
52
+ print('Use: ', device)
53
+ model = build_model(checkpoint, cfg.model_cfg)
54
+
55
+ with torch.no_grad():
56
+ pred_results = model.predict_step(sample, batch_idx=0)
57
+
58
+ cfg.visualizer.setdefault('save_dir', 'visualizer')
59
+ visualizer = VISUALIZERS.build(cfg.visualizer)
60
+
61
+ data_sample = pred_results[0]
62
+ img = np.array(ori_img).copy()
63
+ out_file = 'visualizer/test_img.jpg'
64
+ mmengine.mkdir_or_exist(os.path.dirname(out_file))
65
+ visualizer.add_datasample(
66
+ 'test_img',
67
+ img,
68
+ draw_gt=False,
69
+ data_sample=data_sample,
70
+ show=False,
71
+ wait_time=0.01,
72
+ pred_score_thr=0.4,
73
+ out_file=out_file
74
+ )
75
+ img_bytes = get(out_file)
76
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
77
+ return img
78
+
79
+ title = "RSPrompter"
80
+ description = "Gradio demo for RSPrompter. Upload image from WHU building dataset, NWPU dataset, or SSDD Dataset or click any one of the examples, " \
81
+ "Then select the prompt model, and click \"Submit\" and wait for the result. \n \n" \
82
+ "Paper: RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"
83
+
84
+ article = "<p style='text-align: center'><a href='https://kyanchen.github.io/RSPrompter/' target='_blank'>RSPrompter Project " \
85
+ "Page</a></p> "
86
+
87
+ files = glob.glob('examples/NWPU*')
88
+ examples = [[f, f.split('/')[-1].split('_')[0]] for f in files]
89
+
90
+ with gr.Blocks() as demo:
91
+ image_input = gr.Image(type='pil', label='Input Img')
92
+ # with gr.Row().style(equal_height=True):
93
+ # image_LR_output = gr.outputs.Image(label='LR Img', type='numpy')
94
+ image_output = gr.Image(label='Segment Result', type='numpy')
95
+ with gr.Row():
96
+ checkpoint = gr.Radio(['WHU', 'NWPU', 'SSDD'], label='Checkpoint')
97
+
98
+ io = gr.Interface(fn=inference_func,
99
+ inputs=[image_input, checkpoint],
100
+ outputs=[image_output],
101
+ title=title,
102
+ description=description,
103
+ article=article,
104
+ allow_flagging='auto',
105
+ examples=examples,
106
+ cache_examples=True,
107
+ layout="grid"
108
+ )
109
+ io.launch()
qs_run.sh ADDED
@@ -0,0 +1,32 @@
1
+ #!/bin/bash
2
+ source ~/.bashrc
3
+ conda activate torch2mmcv2 # torch1mmcv1 torch1mmcv2 torch2mmcv1 torch2mmcv2
4
+ pip install albumentations
5
+ pip install importlib_metadata
6
+ pip install --upgrade mmengine
7
+ pip install instaboostfast
8
+ #pip install deepspeed
9
+ # pip install anypackage
10
+ # yum install which
11
+ # source /opt/rh/devtoolset-9/enable
12
+ # mim install mmcv>=2.0.0rc4
13
+
14
+ cd /mnt/search01/usr/chenkeyan/codes/lightning_framework
15
+ #TORCH_DISTRIBUTED_DEBUG=DETAIL
16
+ case $# in
17
+ 0)
18
+ python tools/train.py
19
+ ;;
20
+ 1)
21
+ python tools/train.py --config $1
22
+ ;;
23
+ 2)
24
+ python tools/train.py --config $1 --ckpt-path $2
25
+ ;;
26
+ esac
27
+ # TORCH_DISTRIBUTED_DEBUG=DETAIL
28
+ #python train.py
29
+ #python -m torch.distributed.launch --nproc_per_node=$GPU_NUM --nnodes=$WORLD_SIZE --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env train.py
30
+ #python -m torch.distributed.launch --nproc_per_node=$GPU_NUM --nnodes=$WORLD_SIZE --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env train_pipe.py
31
+ # juicesync src dst
32
+ # juicefs rmr your_dir
readme 2.md ADDED
@@ -0,0 +1,117 @@
1
+ # RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model
2
+
3
+ English | [简体中文](/readme_cn.md)
4
+
5
+ This is the PyTorch implementation of our paper "RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model".
6
+
7
+
8
+ [Project Page](https://kyanchen.github.io/RSPrompter/) $\cdot$ [PDF Download](https://arxiv.org/abs/2306.16269) $\cdot$ [HuggingFace Demo](https://huggingface.co/spaces/KyanChen/RSPrompter)
9
+
10
+
11
+ ## 0. Environment Setup
12
+
13
+ ### 0.1 Create a virtual environment
14
+
15
+ ```shell
16
+ conda create -n RSPrompter python=3.10
17
+ ```
18
+
19
+ ### 0.2 Activate the virtual environment
20
+ ```shell
21
+ conda activate RSPrompter
22
+ ```
23
+
24
+ ### 0.3 Install PyTorch
25
+ PyTorch 1.x also works, but 2.x is recommended.
26
+ ```shell
27
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
28
+ ```
29
+
30
+ ### 0.3 [Optional] Install PyTorch via conda
31
+ ```shell
32
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
33
+ ```
34
+
35
+ ### 0.4 Install mmcv
36
+ mmcv 2.x is recommended.
37
+ ```shell
38
+ pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0/index.html
39
+ ```
40
+ Please refer to the [installation documentation](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) for more detailed instructions.
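+ 
+ To quickly sanity-check the installation (a minimal optional check, not part of the pipeline), the following should print the installed versions:
+ ```python
+ # Quick installation check: prints the mmcv / PyTorch versions and whether CUDA is visible.
+ import mmcv
+ import torch
+ 
+ print('mmcv:', mmcv.__version__)
+ print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
+ ```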
41
+
42
+ ### 0.5 Install other dependencies
43
+ ```shell
44
+ pip install -r requirements.txt
45
+ ```
46
+
47
+ ## 1. Data Preparation
48
+
49
+ ### 1.1 Dataset
50
+
51
+ #### WHU Dataset
52
+ The WHU dataset can be downloaded from [WHU](https://aistudio.baidu.com/aistudio/datasetdetail/56502). After downloading, put the dataset into the **data** folder, which already contains a few example images.
53
+
54
+ #### NWPU Dataset
55
+ The NWPU dataset can be downloaded from [NWPU](https://aistudio.baidu.com/aistudio/datasetdetail/52812). After downloading, put the dataset into the **data** folder, which already contains a few example images.
56
+
57
+ #### SSDD Dataset
58
+ The SSDD dataset can be downloaded from [SSDD](https://aistudio.baidu.com/aistudio/datasetdetail/100924). After downloading, put the dataset into the **data** folder, which already contains a few example images.
59
+
60
+ ### 1.2 Split the dataset into train and test sets
61
+ The dataset split files and annotation files used in the paper are provided in this project; they are stored in the **data/*/annotations** folders in COCO annotation format.
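+ 
+ The converter in tools/ins_seg/dataset_converters/whu_building_convert.py writes standard COCO-style JSON; schematically, each annotation file looks like the sketch below (field values are illustrative, not taken from the real files):
+ ```python
+ # Schematic COCO-format annotation structure produced by the dataset converters
+ # (illustrative values only).
+ annotation_file = {
+     'images': [
+         {'id': 0, 'file_name': 'example.tif', 'height': 512, 'width': 512},
+     ],
+     'annotations': [
+         {'id': 0, 'image_id': 0, 'category_id': 1, 'iscrowd': 0,
+          'bbox': [10.0, 20.0, 30.0, 40.0],      # COCO XYWH format
+          'area': 1200.0,
+          'segmentation': {'size': [512, 512], 'counts': '...'}},  # RLE mask
+     ],
+     'categories': [{'id': 1, 'name': 'building'}],   # e.g., for the WHU dataset
+ }
+ ```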
62
+
63
+ ## 2. Model Training
64
+
65
+ ### 2.1 Train SAM-based model
66
+
67
+ #### 2.1.1 Config file
68
+ The config files are located in the **configs/rsprompter** folder and can be modified as needed. Configs are provided for three models: SAM-seg, SAM-det, and RSPrompter.
69
+
70
+ #### 2.1.2 Train
71
+ Some training parameters can also be modified in the configuration file above, mainly those in trainer_cfg (e.g., single-GPU vs. multi-GPU training); for the specific options, please refer to the PyTorch Lightning Trainer documentation. An illustrative sketch follows the command below.
72
+ ```shell
73
+ python tools/train.py
74
+ ```
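+ 
+ As a rough illustration (the exact keys are defined by the config files in **configs/rsprompter**, not by this sketch), the trainer_cfg section forwards standard PyTorch Lightning Trainer arguments such as:
+ ```python
+ # Illustrative sketch only: typical Lightning Trainer arguments that a
+ # trainer_cfg section would forward; adapt to the actual config in this repo.
+ trainer_cfg = dict(
+     accelerator='gpu',     # or 'cpu'
+     devices=4,             # set devices=1 for single-GPU training
+     strategy='ddp',        # multi-GPU strategy
+     max_epochs=100,
+     precision='16-mixed',  # mixed-precision training
+ )
+ ```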
75
+
76
+ ### 2.2 [Optional] Train other models
77
+ #### 2.2.1 Config file
78
+ The config files in the **configs/rsprompter** folder cover only Mask R-CNN and Mask2Former. Configs for other models can be written by referring to these two files and the model configs in MMDetection.
79
+
80
+ #### 2.2.2 Train
81
+ Modify the config path in **tools/train.py** and then run
82
+ ```shell
83
+ python tools/train.py
84
+ ```
85
+
86
+ ## 3. Model Evaluation
87
+ The config files are located in the **configs/rsprompter** folder and can be modified as needed.
88
+ When val_evaluator and val_loader are configured in the configuration file, the model is automatically evaluated on the validation set during training, and the results are uploaded to Wandb, where they can be viewed.
89
+ To perform offline evaluation on the test set, configure test_evaluator and test_loader in the configuration file (a sketch is given below), set the config and ckpt-path in **tools/test.py**, and then run
90
+ ```shell
91
+ python tools/test.py
92
+ ```
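+ 
+ As a rough illustration (field names follow the MMDetection convention; the actual keys and paths are defined by the configs in this repo), an offline test configuration could look like:
+ ```python
+ # Illustrative sketch of a test evaluator/loader in MMDetection style;
+ # the annotation path below is the WHU test split referenced by this project.
+ test_evaluator = dict(
+     type='CocoMetric',
+     ann_file='data/WHU/annotations/WHU_building_test.json',
+     metric=['bbox', 'segm'],
+ )
+ test_loader = dict(
+     batch_size=2,
+     num_workers=2,
+     dataset=dict(
+         type='CocoDataset',
+         ann_file='data/WHU/annotations/WHU_building_test.json',
+         data_prefix=dict(img='data/WHU/test/image'),   # hypothetical image folder
+     ),
+ )
+ ```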
93
+
94
+ ## 4. [Optional] Model Visualization
95
+ The config files are located in the **configs/rsprompter** folder and can be modified as needed. Adjust the parameters of DetVisualizationHook and DetLocalVisualizer in the configuration file (an illustrative sketch follows the command below), set the config and ckpt-path in **tools/predict.py**, and then run
96
+ ```shell
97
+ python tools/predict.py
98
+ ```
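+ 
+ As a rough illustration (the exact settings live in the **configs/rsprompter** files), the visualization pieces mentioned above are standard MMDetection components:
+ ```python
+ # Illustrative sketch of the visualization settings; values are examples only.
+ visualizer = dict(
+     type='DetLocalVisualizer',
+     vis_backends=[dict(type='LocalVisBackend')],
+     name='visualizer',
+ )
+ default_hooks = dict(
+     visualization=dict(
+         type='DetVisualizationHook',
+         draw=True,        # draw predictions onto the input images
+         score_thr=0.4,    # only draw predictions above this confidence
+     ),
+ )
+ ```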
99
+
100
+ ## 5. [Optional] Model Download
101
+ This project provides the RSPrompter-anchor model weights, which are available in the [huggingface space](https://huggingface.co/spaces/KyanChen/RSPrompter/tree/main/pretrain).
102
+
103
+ ## 6. [Optional] Citation
104
+ If you find this project useful for your research, please cite our paper.
105
+
106
+ If you have any other questions, please feel free to contact me.
107
+
108
+ ```
109
+ @misc{chen2023rsprompter,
110
+ title={RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model},
111
+ author={Keyan Chen and Chenyang Liu and Hao Chen and Haotian Zhang and Wenyuan Li and Zhengxia Zou and Zhenwei Shi},
112
+ year={2023},
113
+ eprint={2306.16269},
114
+ archivePrefix={arXiv},
115
+ primaryClass={cs.CV}
116
+ }
117
+ ```
readme_cn.md ADDED
@@ -0,0 +1,116 @@
1
+ # RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model
2
+
3
+ [English](/readme.md) | 简体中文
4
+
5
+
6
+ 本项目是论文"RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"的Pytorch实现
7
+
8
+
9
+ [项目主页](https://kyanchen.github.io/RSPrompter/) $\cdot$ [PDF下载](https://arxiv.org/abs/2306.16269) $\cdot$ [HuggingFace 样例](https://huggingface.co/spaces/KyanChen/RSPrompter)
10
+
11
+
12
+ ## 0. 环境准备
13
+ ### 0.1 建立虚拟环境
14
+ ```shell
15
+ conda create -n RSPrompter python=3.10
16
+ ```
17
+ ### 0.2 激活虚拟环境
18
+ ```shell
19
+ conda activate RSPrompter
20
+ ```
21
+ ### 0.3 安装pytorch
22
+ 1.x版本也可以,但是建议使用2.x版本
23
+ ```shell
24
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
25
+ ```
26
+ ### 0.3 [可选]安装pytorch
27
+ ```shell
28
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
29
+ ```
30
+ ### 0.4 安装mmcv
31
+ 建议2.x版本
32
+ ```shell
33
+ pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0/index.html
34
+ ```
35
+ 更多安装信息请参考[安装文档](https://mmcv.readthedocs.io/zh_CN/latest/get_started/installation.html)
36
+ ### 0.5 安装其他依赖
37
+ ```shell
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ ## 1. 数据准备
42
+
43
+ ### 1.1 数据集
44
+
45
+ #### WHU数据集
46
+ WHU数据集可以从[WHU](https://aistudio.baidu.com/aistudio/datasetdetail/56502)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。
47
+
48
+ #### NWPU数据集
49
+ NWPU数据集可以从[NWPU](https://aistudio.baidu.com/aistudio/datasetdetail/52812)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。
50
+
51
+ #### SSDD数据集
52
+ SSDD数据集可以从[SSDD](https://aistudio.baidu.com/aistudio/datasetdetail/100924)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。
53
+
54
+ ### 1.2 划分训练测试集
55
+ 在本项目中已提供论文中的数据集划分文件和标注文件,以COCO标注格式存储,位于**data/*/annotations**文件夹中。
56
+
57
+ ## 2. 模型训练
58
+
59
+ ### 2.1 训练SAM-based模型
60
+
61
+ #### 2.1.1 配置文件
62
+ 配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的参数,提供了SAM-seg,SAM-det,RSPrompter三种模型的配置文件。
63
+
64
+ #### 2.1.2 训练
65
+ 训练的一些参数配置也可以在上述配置文件中修改,主要修改trainer_cfg中的参数,例如单卡多卡训练等,具体配置修改参考Pytorch Lightning的Trainer。
66
+ ```shell
67
+ python tools/train.py
68
+ ```
69
+
70
+
71
+ ### 2.2 [可选] 训练其他模型
72
+ #### 2.2.1 配置文件
73
+ 配置文件位于**configs/rsprompter**文件夹中,仅提供了Mask R-CNN和Mask2Former的配置,其他模型的配置可以参考这两个配置文件和MMDetection中的模型config进行修改。
74
+
75
+ #### 2.2.2 训练
76
+ 修改**tools/train.py**中的config路径,然后运行
77
+ ```shell
78
+ python tools/train.py
79
+ ```
80
+
81
+
82
+ ## 3. 模型评测
83
+
84
+ 模型配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的参数。
85
+ 当配置了该文件中val_evaluator和val_loader,在模型训练时,会自动进行模型在验证集上的评测,评测结果会上传到Wandb中,可以在Wandb中查看。
86
+ 如果需要在测试集上进行离线评测,需要配置配置文件中的test_evaluator和test_loader,以及**tools/test.py**中的config和ckpt-path路径,然后运行
87
+ ```shell
88
+ python tools/test.py
89
+ ```
90
+
91
+ ## 4. [可选]结果可视化
92
+ 模型配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的**DetVisualizationHook**和**DetLocalVisualizer**的参数,
93
+ 以及**tools/predict.py**中的config和ckpt-path路径,然后运行
94
+ ```shell
95
+ python tools/predict.py
96
+ ```
97
+
98
+
99
+ ## 5. [可选]模型下载
100
+ 本项目提供了RSPrompter-anchor的模型权重,位于[huggingface space](https://huggingface.co/spaces/KyanChen/RSPrompter/tree/main/pretrain)中
101
+
102
+ ## 6. [可选]引用
103
+ 如果您认为本项目对您的研究有所帮助,请引用我们的论文.
104
+
105
+ 如果您有其他问题,请联系我!!!
106
+
107
+ ```
108
+ @misc{chen2023rsprompter,
109
+ title={RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model},
110
+ author={Keyan Chen and Chenyang Liu and Hao Chen and Haotian Zhang and Wenyuan Li and Zhengxia Zou and Zhenwei Shi},
111
+ year={2023},
112
+ eprint={2306.16269},
113
+ archivePrefix={arXiv},
114
+ primaryClass={cs.CV}
115
+ }
116
+ ```
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ torch>=2.0.0
2
+ mmcv>=2.0.0rc4
3
+ importlib_metadata
4
+ mat4py
5
+
6
+ torchvision
7
+ lightning==2.0.1
8
+ mmengine
9
+ openmim
10
+ wandb
11
+ redis
12
+ scikit-learn
13
+ scikit-image
14
+ tenacity
15
+ torchmetrics
16
+ tensorboardx
17
+ transformers
18
+ ipdb
19
+ prettytable
20
+ einops
21
+
22
+ imageio
23
+ pycocotools
24
+ shapely
25
+ terminaltables
26
+ albumentations
tools/__init__.py ADDED
File without changes
tools/ins_seg/analysis_tools/browse_dataset_mmdet_mmyolo_pl.py ADDED
@@ -0,0 +1,266 @@
1
+ import argparse
2
+ import os.path as osp
3
+ import sys
4
+ from typing import Tuple
5
+
6
+ import cv2
7
+ import mmcv
8
+ import numpy as np
9
+ from mmdet.models.utils import mask2ndarray
10
+ from mmdet.structures.bbox import BaseBoxes
11
+ from mmengine.config import Config, DictAction
12
+ from mmengine.dataset import Compose
13
+ from mmengine.registry import init_default_scope
14
+ from mmengine.utils import ProgressBar
15
+ from mmengine.visualization import Visualizer
16
+
17
+ from mmyolo.registry import DATASETS, VISUALIZERS
18
+
19
+
20
+ # TODO: Support for printing the change in key of results
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser(description='Browse a dataset')
23
+ parser.add_argument('--config', default='configs/ins_seg/seg_maskrcnn_whu_bs8_config.py', help='train config file path')
24
+ parser.add_argument(
25
+ '--phase',
26
+ '-p',
27
+ default='val',
28
+ type=str,
29
+ choices=['train', 'test', 'val'],
30
+ help='phase of dataset to visualize, accept "train" "test" and "val".'
31
+ ' Defaults to "train".')
32
+ parser.add_argument(
33
+ '--mode',
34
+ '-m',
35
+ default='transformed',
36
+ type=str,
37
+ choices=['original', 'transformed', 'pipeline'],
38
+ help='display mode; display original pictures or '
39
+ 'transformed pictures or comparison pictures. "original" '
40
+ 'means show images load from disk; "transformed" means '
41
+ 'to show images after transformed; "pipeline" means show all '
42
+ 'the intermediate images. Defaults to "transformed".')
43
+ parser.add_argument(
44
+ '--out-dir',
45
+ default='output',
46
+ type=str,
47
+ help='If there is no display interface, you can save it.')
48
+ parser.add_argument('--not-show', default=False, action='store_true')
49
+ parser.add_argument(
50
+ '--show-number',
51
+ '-n',
52
+ type=int,
53
+ default=sys.maxsize,
54
+ help='number of images selected to visualize, '
55
+ 'must bigger than 0. if the number is bigger than length '
56
+ 'of dataset, show all the images in dataset; '
57
+ 'default "sys.maxsize", show all images in dataset')
58
+ parser.add_argument(
59
+ '--show-interval',
60
+ '-i',
61
+ type=float,
62
+ default=3,
63
+ help='the interval of show (s)')
64
+ parser.add_argument(
65
+ '--cfg-options',
66
+ nargs='+',
67
+ action=DictAction,
68
+ help='override some settings in the used config, the key-value pair '
69
+ 'in xxx=yyy format will be merged into config file. If the value to '
70
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
71
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
72
+ 'Note that the quotation marks are necessary and that no white space '
73
+ 'is allowed.')
74
+ args = parser.parse_args()
75
+ return args
76
+
77
+
78
+ def _get_adaptive_scale(img_shape: Tuple[int, int],
79
+ min_scale: float = 0.3,
80
+ max_scale: float = 3.0) -> float:
81
+ """Get adaptive scale according to image shape.
82
+
83
+ The target scale depends on the short edge length of the image. If the
84
+ short edge length equals 224, the output is 1.0. And output linear
85
+ scales according the short edge length. You can also specify the minimum
86
+ scale and the maximum scale to limit the linear scale.
87
+
88
+ Args:
89
+ img_shape (Tuple[int, int]): The shape of the canvas image.
90
+ min_scale (int): The minimum scale. Defaults to 0.3.
91
+ max_scale (int): The maximum scale. Defaults to 3.0.
92
+ Returns:
93
+ int: The adaptive scale.
94
+ """
95
+ short_edge_length = min(img_shape)
96
+ scale = short_edge_length / 224.
97
+ return min(max(scale, min_scale), max_scale)
98
+
99
+
100
+ def make_grid(imgs, names):
101
+ """Concat list of pictures into a single big picture, align height here."""
102
+ visualizer = Visualizer.get_current_instance()
103
+ ori_shapes = [img.shape[:2] for img in imgs]
104
+ max_height = int(max(img.shape[0] for img in imgs) * 1.1)
105
+ min_width = min(img.shape[1] for img in imgs)
106
+ horizontal_gap = min_width // 10
107
+ img_scale = _get_adaptive_scale((max_height, min_width))
108
+
109
+ texts = []
110
+ text_positions = []
111
+ start_x = 0
112
+ for i, img in enumerate(imgs):
113
+ pad_height = (max_height - img.shape[0]) // 2
114
+ pad_width = horizontal_gap // 2
115
+ # make border
116
+ imgs[i] = cv2.copyMakeBorder(
117
+ img,
118
+ pad_height,
119
+ max_height - img.shape[0] - pad_height + int(img_scale * 30 * 2),
120
+ pad_width,
121
+ pad_width,
122
+ cv2.BORDER_CONSTANT,
123
+ value=(255, 255, 255))
124
+ texts.append(f'{"execution: "}{i}\n{names[i]}\n{ori_shapes[i]}')
125
+ text_positions.append(
126
+ [start_x + img.shape[1] // 2 + pad_width, max_height])
127
+ start_x += img.shape[1] + horizontal_gap
128
+
129
+ display_img = np.concatenate(imgs, axis=1)
130
+ visualizer.set_image(display_img)
131
+ img_scale = _get_adaptive_scale(display_img.shape[:2])
132
+ visualizer.draw_texts(
133
+ texts,
134
+ positions=np.array(text_positions),
135
+ font_sizes=img_scale * 7,
136
+ colors='black',
137
+ horizontal_alignments='center',
138
+ font_families='monospace')
139
+ return visualizer.get_image()
140
+
141
+
142
+ class InspectCompose(Compose):
143
+ """Compose multiple transforms sequentially.
144
+
145
+ And record "img" field of all results in one list.
146
+ """
147
+
148
+ def __init__(self, transforms, intermediate_imgs):
149
+ super().__init__(transforms=transforms)
150
+ self.intermediate_imgs = intermediate_imgs
151
+
152
+ def __call__(self, data):
153
+ if 'img' in data:
154
+ self.intermediate_imgs.append({
155
+ 'name': 'original',
156
+ 'img': data['img'].copy()
157
+ })
158
+ self.ptransforms = [
159
+ self.transforms[i] for i in range(len(self.transforms) - 1)
160
+ ]
161
+ for t in self.ptransforms:
162
+ data = t(data)
163
+ # Keep the same meta_keys in the PackDetInputs
164
+ self.transforms[-1].meta_keys = [key for key in data]
165
+ data_sample = self.transforms[-1](data)
166
+ if data is None:
167
+ return None
168
+ if 'img' in data:
169
+ self.intermediate_imgs.append({
170
+ 'name':
171
+ t.__class__.__name__,
172
+ 'dataset_sample':
173
+ data_sample['data_samples']
174
+ })
175
+ return data
176
+
177
+
178
+ def main():
179
+ args = parse_args()
180
+ cfg = Config.fromfile(args.config)
181
+ if args.cfg_options is not None:
182
+ cfg.merge_from_dict(args.cfg_options)
183
+
184
+ init_default_scope(cfg.get('default_scope', 'mmpl'))
185
+
186
+ dataset_cfg = cfg.get('datamodule_cfg').get(args.phase + '_loader').get('dataset')
187
+ dataset = DATASETS.build(dataset_cfg)
188
+
189
+ # self added
190
+ vis_backends = [dict(type='mmdet.LocalVisBackend')]
191
+ visualizer = dict(
192
+ type='mmdet.DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
193
+
194
+ visualizer = VISUALIZERS.build(visualizer)
195
+ visualizer.dataset_meta = dataset.metainfo
196
+
197
+ intermediate_imgs = []
198
+
199
+ if not hasattr(dataset, 'pipeline'):
200
+ # for dataset_wrapper
201
+ dataset = dataset.dataset
202
+
203
+ # TODO: The dataset wrapper occasion is not considered here
204
+ dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
205
+ intermediate_imgs)
206
+
207
+ # init visualization image number
208
+ assert args.show_number > 0
209
+ display_number = min(args.show_number, len(dataset))
210
+
211
+ progress_bar = ProgressBar(display_number)
212
+ for i, item in zip(range(display_number), dataset):
213
+ image_i = []
214
+ result_i = [result['dataset_sample'] for result in intermediate_imgs]
215
+ for k, datasample in enumerate(result_i):
216
+ image = datasample.img
217
+ gt_instances = datasample.gt_instances
218
+ image = image[..., [2, 1, 0]] # bgr to rgb
219
+ gt_bboxes = gt_instances.get('bboxes', None)
220
+ if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes):
221
+ gt_instances.bboxes = gt_bboxes.tensor
222
+ gt_masks = gt_instances.get('masks', None)
223
+ if gt_masks is not None:
224
+ masks = mask2ndarray(gt_masks)
225
+ gt_instances.masks = masks.astype(bool)
226
+ datasample.gt_instances = gt_instances
227
+ # get filename from dataset or just use index as filename
228
+ visualizer.add_datasample(
229
+ 'result',
230
+ image,
231
+ datasample,
232
+ draw_pred=False,
233
+ draw_gt=True,
234
+ show=False)
235
+ image_show = visualizer.get_image()
236
+ image_i.append(image_show)
237
+
238
+ if args.mode == 'original':
239
+ image = image_i[0]
240
+ elif args.mode == 'transformed':
241
+ image = image_i[-1]
242
+ else:
243
+ image = make_grid([result for result in image_i],
244
+ [result['name'] for result in intermediate_imgs])
245
+
246
+ if hasattr(datasample, 'img_path'):
247
+ filename = osp.basename(datasample.img_path)
248
+ else:
249
+ # some datasets have no image path
250
+ filename = f'{i}.jpg'
251
+ out_file = osp.join(args.out_dir,
252
+ filename) if args.out_dir is not None else None
253
+
254
+ if out_file is not None:
255
+ mmcv.imwrite(image[..., ::-1], out_file)
256
+
257
+ if not args.not_show:
258
+ visualizer.show(
259
+ image, win_name=filename, wait_time=args.show_interval)
260
+
261
+ intermediate_imgs.clear()
262
+ progress_bar.update()
263
+
264
+
265
+ if __name__ == '__main__':
266
+ main()
tools/ins_seg/analysis_tools/dataset_analysis.py ADDED
@@ -0,0 +1,518 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path
4
+ from statistics import median
5
+
6
+ import matplotlib.patches as mpatches
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from mmengine.config import Config
10
+ from mmengine.registry import init_default_scope
11
+ from mmengine.utils import ProgressBar
12
+ from prettytable import PrettyTable
13
+
14
+ from mmyolo.registry import DATASETS
15
+ from mmyolo.utils.misc import show_data_classes
16
+
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser(
20
+ description='Distribution of categories and bbox instances')
21
+ parser.add_argument('--config', default='configs/ins_seg/seg_sam_queryprompt_ssdd_bs2_last_config.py', help='config file path')
22
+ parser.add_argument(
23
+ '--val-dataset',
24
+ default=False,
25
+ action='store_true',
26
+ help='The default train_dataset.'
27
+ 'To change it to val_dataset, enter "--val-dataset"')
28
+ parser.add_argument(
29
+ '--class-name',
30
+ default=None,
31
+ type=str,
32
+ help='Display specific class, e.g., "bicycle"')
33
+ parser.add_argument(
34
+ '--area-rule',
35
+ default=None,
36
+ type=int,
37
+ nargs='+',
38
+ help='Redefine area rules,but no more than three numbers.'
39
+ ' e.g., 30 70 125')
40
+ parser.add_argument(
41
+ '--func',
42
+ default='show_bbox_num',
43
+ type=str,
44
+ choices=[
45
+ 'show_bbox_num', 'show_bbox_wh', 'show_bbox_wh_ratio',
46
+ 'show_bbox_area'
47
+ ],
48
+ help='Dataset analysis function selection.')
49
+ parser.add_argument(
50
+ '--out-dir',
51
+ default='./results/dataset_analysis',
52
+ type=str,
53
+ help='Output directory of dataset analysis visualization results,'
54
+ ' Save in "./dataset_analysis/" by default')
55
+ args = parser.parse_args()
56
+ return args
57
+
58
+
59
+ def show_bbox_num(cfg, out_dir, fig_set, class_name, class_num):
60
+ """Display the distribution map of categories and number of bbox
61
+ instances."""
62
+ print('\n\nDrawing bbox_num figure:')
63
+ # Draw designs
64
+ fig = plt.figure(
65
+ figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
66
+ plt.bar(class_name, class_num, align='center')
67
+
68
+ # Draw titles, labels and so on
69
+ for x, y in enumerate(class_num):
70
+ plt.text(x, y, '%s' % y, ha='center', fontsize=fig_set['fontsize'] + 3)
71
+ plt.xticks(rotation=fig_set['xticks_angle'])
72
+ plt.xlabel('Category Name')
73
+ plt.ylabel('Num of instances')
74
+ plt.title(cfg.dataset_type)
75
+
76
+ # Save figure
77
+ if not os.path.exists(out_dir):
78
+ os.makedirs(out_dir)
79
+ out_name = fig_set['out_name']
80
+ fig.savefig(
81
+ f'{out_dir}/{out_name}_bbox_num.jpg',
82
+ bbox_inches='tight',
83
+ pad_inches=0.1) # Save Image
84
+ plt.close()
85
+ print(f'End and save in {out_dir}/{out_name}_bbox_num.jpg')
86
+
87
+
88
+ def show_bbox_wh(out_dir, fig_set, class_bbox_w, class_bbox_h, class_name):
89
+ """Display the width and height distribution of categories and bbox
90
+ instances."""
91
+ print('\n\nDrawing bbox_wh figure:')
92
+ # Draw designs
93
+ fig, ax = plt.subplots(
94
+ figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
95
+
96
+ # Set the position of the map and label on the x-axis
97
+ positions_w = list(range(0, 12 * len(class_name), 12))
98
+ positions_h = list(range(6, 12 * len(class_name), 12))
99
+ positions_x_label = list(range(3, 12 * len(class_name) + 1, 12))
100
+ ax.violinplot(
101
+ class_bbox_w, positions_w, showmeans=True, showmedians=True, widths=4)
102
+ ax.violinplot(
103
+ class_bbox_h, positions_h, showmeans=True, showmedians=True, widths=4)
104
+
105
+ # Draw titles, labels and so on
106
+ plt.xticks(rotation=fig_set['xticks_angle'])
107
+ plt.ylabel('The width or height of bbox')
108
+ plt.xlabel('Class name')
109
+ plt.title('Width or height distribution of classes and bbox instances')
110
+
111
+ # Draw the max, min and median of wide data in violin chart
112
+ for i in range(len(class_bbox_w)):
113
+ plt.text(
114
+ positions_w[i],
115
+ median(class_bbox_w[i]),
116
+ f'{"%.2f" % median(class_bbox_w[i])}',
117
+ ha='center',
118
+ fontsize=fig_set['fontsize'])
119
+ plt.text(
120
+ positions_w[i],
121
+ max(class_bbox_w[i]),
122
+ f'{"%.2f" % max(class_bbox_w[i])}',
123
+ ha='center',
124
+ fontsize=fig_set['fontsize'])
125
+ plt.text(
126
+ positions_w[i],
127
+ min(class_bbox_w[i]),
128
+ f'{"%.2f" % min(class_bbox_w[i])}',
129
+ ha='center',
130
+ fontsize=fig_set['fontsize'])
131
+
132
+ # Draw the max, min and median of height data in violin chart
133
+ for i in range(len(positions_h)):
134
+ plt.text(
135
+ positions_h[i],
136
+ median(class_bbox_h[i]),
137
+ f'{"%.2f" % median(class_bbox_h[i])}',
138
+ ha='center',
139
+ fontsize=fig_set['fontsize'])
140
+ plt.text(
141
+ positions_h[i],
142
+ max(class_bbox_h[i]),
143
+ f'{"%.2f" % max(class_bbox_h[i])}',
144
+ ha='center',
145
+ fontsize=fig_set['fontsize'])
146
+ plt.text(
147
+ positions_h[i],
148
+ min(class_bbox_h[i]),
149
+ f'{"%.2f" % min(class_bbox_h[i])}',
150
+ ha='center',
151
+ fontsize=fig_set['fontsize'])
152
+
153
+ # Draw Legend
154
+ plt.setp(ax, xticks=positions_x_label, xticklabels=class_name)
155
+ labels = ['bbox_w', 'bbox_h']
156
+ colors = ['steelblue', 'darkorange']
157
+ patches = [
158
+ mpatches.Patch(color=colors[i], label=f'{labels[i]:s}')
159
+ for i in range(len(colors))
160
+ ]
161
+ ax = plt.gca()
162
+ box = ax.get_position()
163
+ ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
164
+ ax.legend(loc='upper center', handles=patches, ncol=2)
165
+
166
+ # Save figure
167
+ if not os.path.exists(out_dir):
168
+ os.makedirs(out_dir)
169
+ out_name = fig_set['out_name']
170
+ fig.savefig(
171
+ f'{out_dir}/{out_name}_bbox_wh.jpg',
172
+ bbox_inches='tight',
173
+ pad_inches=0.1) # Save Image
174
+ plt.close()
175
+ print(f'End and save in {out_dir}/{out_name}_bbox_wh.jpg')
176
+
177
+
178
+ def show_bbox_wh_ratio(out_dir, fig_set, class_name, class_bbox_ratio):
179
+ """Display the distribution map of category and bbox instance width and
180
+ height ratio."""
181
+ print('\n\nDrawing bbox_wh_ratio figure:')
182
+ # Draw designs
183
+ fig, ax = plt.subplots(
184
+ figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
185
+
186
+ # Set the position of the map and label on the x-axis
187
+ positions = list(range(0, 6 * len(class_name), 6))
188
+ ax.violinplot(
189
+ class_bbox_ratio,
190
+ positions,
191
+ showmeans=True,
192
+ showmedians=True,
193
+ widths=5)
194
+
195
+ # Draw titles, labels and so on
196
+ plt.xticks(rotation=fig_set['xticks_angle'])
197
+ plt.ylabel('Ratio of width to height of bbox')
198
+ plt.xlabel('Class name')
199
+ plt.title('Width to height ratio distribution of class and bbox instances')
200
+
201
+ # Draw the max, min and median of wide data in violin chart
202
+ for i in range(len(class_bbox_ratio)):
203
+ plt.text(
204
+ positions[i],
205
+ median(class_bbox_ratio[i]),
206
+ f'{"%.2f" % median(class_bbox_ratio[i])}',
207
+ ha='center',
208
+ fontsize=fig_set['fontsize'])
209
+ plt.text(
210
+ positions[i],
211
+ max(class_bbox_ratio[i]),
212
+ f'{"%.2f" % max(class_bbox_ratio[i])}',
213
+ ha='center',
214
+ fontsize=fig_set['fontsize'])
215
+ plt.text(
216
+ positions[i],
217
+ min(class_bbox_ratio[i]),
218
+ f'{"%.2f" % min(class_bbox_ratio[i])}',
219
+ ha='center',
220
+ fontsize=fig_set['fontsize'])
221
+
222
+ # Set the position of the map and label on the x-axis
223
+ plt.setp(ax, xticks=positions, xticklabels=class_name)
224
+
225
+ # Save figure
226
+ if not os.path.exists(out_dir):
227
+ os.makedirs(out_dir)
228
+ out_name = fig_set['out_name']
229
+ fig.savefig(
230
+ f'{out_dir}/{out_name}_bbox_ratio.jpg',
231
+ bbox_inches='tight',
232
+ pad_inches=0.1) # Save Image
233
+ plt.close()
234
+ print(f'End and save in {out_dir}/{out_name}_bbox_ratio.jpg')
235
+
236
+
237
+ def show_bbox_area(out_dir, fig_set, area_rule, class_name, bbox_area_num):
238
+ """Display the distribution map of category and bbox instance area based on
239
+ the rules of large, medium and small objects."""
240
+ print('\n\nDrawing bbox_area figure:')
241
+ # Set the direct distance of each label and the width of each histogram
242
+ # Set the required labels and colors
243
+ positions = np.arange(0, 2 * len(class_name), 2)
244
+ width = 0.4
245
+ labels = ['Small', 'Medium', 'Large', 'Huge']
246
+ colors = ['#438675', '#F7B469', '#6BA6DA', '#913221']
247
+
248
+ # Draw designs
249
+ fig = plt.figure(
250
+ figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
251
+ for i in range(len(area_rule) - 1):
252
+ area_num = [bbox_area_num[idx][i] for idx in range(len(class_name))]
253
+ plt.bar(
254
+ positions + width * i,
255
+ area_num,
256
+ width,
257
+ label=labels[i],
258
+ color=colors[i])
259
+ for idx, (x, y) in enumerate(zip(positions.tolist(), area_num)):
260
+ plt.text(
261
+ x + width * i,
262
+ y,
263
+ y,
264
+ ha='center',
265
+ fontsize=fig_set['fontsize'] - 1)
266
+
267
+ # Draw titles, labels and so on
268
+ plt.xticks(rotation=fig_set['xticks_angle'])
269
+ plt.xticks(positions + width * ((len(area_rule) - 2) / 2), class_name)
270
+ plt.ylabel('Class Area')
271
+ plt.xlabel('Class Name')
272
+ plt.title(
273
+ 'Area and number of large, medium and small objects of each class')
274
+
275
+ # Set and Draw Legend
276
+ patches = [
277
+ mpatches.Patch(color=colors[i], label=f'{labels[i]:s}')
278
+ for i in range(len(area_rule) - 1)
279
+ ]
280
+ ax = plt.gca()
281
+ box = ax.get_position()
282
+ ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
283
+ ax.legend(loc='upper center', handles=patches, ncol=len(area_rule) - 1)
284
+
285
+ # Save figure
286
+ if not os.path.exists(out_dir):
287
+ os.makedirs(out_dir)
288
+ out_name = fig_set['out_name']
289
+ fig.savefig(
290
+ f'{out_dir}/{out_name}_bbox_area.jpg',
291
+ bbox_inches='tight',
292
+ pad_inches=0.1) # Save Image
293
+ plt.close()
294
+ print(f'End and save in {out_dir}/{out_name}_bbox_area.jpg')
295
+
296
+
297
+ def show_class_list(classes, class_num):
298
+ """Print the data of the class obtained by the current run."""
299
+ print('\n\nThe information obtained is as follows:')
300
+ class_info = PrettyTable()
301
+ class_info.title = 'Information of dataset class'
302
+ # List Print Settings
303
+ # If the quantity is too large, 25 rows will be displayed in each column
304
+ if len(classes) < 25:
305
+ class_info.add_column('Class name', classes)
306
+ class_info.add_column('Bbox num', class_num)
307
+ elif len(classes) % 25 != 0 and len(classes) > 25:
308
+ col_num = int(len(classes) / 25) + 1
309
+ class_nums = class_num.tolist()
310
+ class_name_list = list(classes)
311
+ for i in range(0, (col_num * 25) - len(classes)):
312
+ class_name_list.append('')
313
+ class_nums.append('')
314
+ for i in range(0, len(class_name_list), 25):
315
+ class_info.add_column('Class name', class_name_list[i:i + 25])
316
+ class_info.add_column('Bbox num', class_nums[i:i + 25])
317
+
318
+ # Align display data to the left
319
+ class_info.align['Class name'] = 'l'
320
+ class_info.align['Bbox num'] = 'l'
321
+ print(class_info)
322
+
323
+
324
+ def show_data_list(args, area_rule):
325
+ """Print run setup information."""
326
+ print('\n\nPrint current running information:')
327
+ data_info = PrettyTable()
328
+ data_info.title = 'Dataset information'
329
+ # Print the corresponding information according to the settings
330
+ if args.val_dataset is False:
331
+ data_info.add_column('Dataset type', ['train_dataset'])
332
+ elif args.val_dataset is True:
333
+ data_info.add_column('Dataset type', ['val_dataset'])
334
+ if args.class_name is None:
335
+ data_info.add_column('Class name', ['All classes'])
336
+ else:
337
+ data_info.add_column('Class name', [args.class_name])
338
+ if args.func is None:
339
+ data_info.add_column('Function', ['All function'])
340
+ else:
341
+ data_info.add_column('Function', [args.func])
342
+ data_info.add_column('Area rule', [area_rule])
343
+
344
+ print(data_info)
345
+
346
+
347
+ def main():
348
+ args = parse_args()
349
+ cfg = Config.fromfile(args.config)
350
+
351
+ init_default_scope(cfg.get('default_scope', 'mmpl'))
352
+
353
+ def replace_pipeline_to_none(cfg):
354
+ """Recursively iterate over all dataset(or datasets) and set their
355
+ pipelines to None. Here "datasets" refers to a ConcatDataset.
356
+
357
+ Recursively terminates only when all dataset(or datasets) have been
358
+ traversed
359
+ """
360
+
361
+ if cfg.get('dataset', None) is None and cfg.get('datasets',
362
+ None) is None:
363
+ return
364
+ dataset = cfg.dataset if cfg.get('dataset', None) else cfg.datasets
365
+ if isinstance(dataset, list):
366
+ for item in dataset:
367
+ item.pipeline = None
368
+ elif dataset.get('pipeline', None):
369
+ dataset.pipeline = None
370
+ else:
371
+ replace_pipeline_to_none(dataset)
372
+
373
+ # 1.Build Dataset
374
+ dataset_cfg = cfg.get('datamodule_cfg')
375
+ if args.val_dataset is False:
376
+ replace_pipeline_to_none(dataset_cfg.train_loader)
377
+ dataset = DATASETS.build(dataset_cfg.train_loader.dataset)
378
+ else:
379
+ replace_pipeline_to_none(dataset_cfg.val_loader)
380
+ dataset = DATASETS.build(dataset_cfg.val_loader.dataset)
381
+
382
+ # 2.Prepare data
383
+ # Drawing settings
384
+ fig_all_set = {
385
+ 'figsize': [35, 18],
386
+ 'fontsize': int(10 - 0.08 * len(dataset.metainfo['classes'])),
387
+ 'xticks_angle': 70,
388
+ 'out_name': cfg.dataset_type
389
+ }
390
+ fig_one_set = {
391
+ 'figsize': [15, 10],
392
+ 'fontsize': 10,
393
+ 'xticks_angle': 0,
394
+ 'out_name': args.class_name
395
+ }
396
+
397
+ # Call the category name and save address
398
+ if args.class_name is None:
399
+ classes = dataset.metainfo['classes']
400
+ classes_idx = [i for i in range(len(classes))]
401
+ fig_set = fig_all_set
402
+ elif args.class_name in dataset.metainfo['classes']:
403
+ classes = [args.class_name]
404
+ classes_idx = [dataset.metainfo['classes'].index(args.class_name)]
405
+ fig_set = fig_one_set
406
+ else:
407
+ data_classes = dataset.metainfo['classes']
408
+ show_data_classes(data_classes)
409
+ raise RuntimeError(f'Expected args.class_name to be one of the list,'
410
+ f'but got "{args.class_name}"')
411
+
412
+ # Building Area Rules
413
+ if args.area_rule is None:
414
+ area_rule = [0, 32, 96, 1e5]
415
+ elif args.area_rule and len(args.area_rule) <= 3:
416
+ area_rules = [0] + args.area_rule + [1e5]
417
+ area_rule = sorted(area_rules)
418
+ else:
419
+ raise RuntimeError(
420
+ f'Expected the "{args.area_rule}" to be e.g. 30 60 120, '
421
+ 'and no more than three numbers.')
422
+
423
+ # Build arrays or lists to store data for each category
424
+ class_num = np.zeros((len(classes), ), dtype=np.int64)
425
+ class_bbox = [[] for _ in classes]
426
+ class_name = []
427
+ class_bbox_w = []
428
+ class_bbox_h = []
429
+ class_bbox_ratio = []
430
+ bbox_area_num = []
431
+ instance_num = []
432
+
433
+ show_data_list(args, area_rule)
434
+ # Get the quantity and bbox data corresponding to each category
435
+ print('\nRead the information of each picture in the dataset:')
436
+ progress_bar = ProgressBar(len(dataset))
437
+
438
+ counts_instances = 0
439
+ for index in range(len(dataset)):
440
+ instances = dataset[index]['instances']
441
+ # if len(instances) > 100:
442
+ # counts_instances += 1
443
+ # # continue
444
+ # labels = [instance['bbox_label'] for instance in instances]
445
+ # counts = np.bincount(labels)
446
+ # label_id = np.argmax(counts)
447
+ # # Harbor Large_Vehicle Small_Vehicle ship
448
+ # print(f'the class is {dataset.metainfo["classes"][label_id]}')
449
+ # print('The number of bboxes in the picture is greater than 100')
450
+ instance_num.append(len(instances))
451
+ for instance in dataset[index]['instances']:
452
+ if instance[
453
+ 'bbox_label'] in classes_idx and args.class_name is None:
454
+ class_num[instance['bbox_label']] += 1
455
+ class_bbox[instance['bbox_label']].append(instance['bbox'])
456
+ elif instance['bbox_label'] in classes_idx and args.class_name:
457
+ class_num[0] += 1
458
+ class_bbox[0].append(instance['bbox'])
459
+ progress_bar.update()
460
+ show_class_list(classes, class_num)
461
+ print(f'The number of bboxes in the picture is greater than 120: {counts_instances}')
462
+ # Get the width, height and area of bbox corresponding to each category
463
+ print('\nRead bbox information in each class:')
464
+ progress_bar_classes = ProgressBar(len(classes))
465
+ for idx, (classes, classes_idx) in enumerate(zip(classes, classes_idx)):
466
+ bbox = np.array(class_bbox[idx])
467
+ bbox_area_nums = np.zeros((len(area_rule) - 1, ), dtype=np.int64)
468
+ if len(bbox) > 0:
469
+ bbox_wh = bbox[:, 2:4] - bbox[:, 0:2]
470
+ bbox_ratio = bbox_wh[:, 0] / bbox_wh[:, 1]
471
+ bbox_area = bbox_wh[:, 0] * bbox_wh[:, 1]
472
+ class_bbox_w.append(bbox_wh[:, 0].tolist())
473
+ class_bbox_h.append(bbox_wh[:, 1].tolist())
474
+ class_bbox_ratio.append(bbox_ratio.tolist())
475
+
476
+ # The area rule defines an interval between two consecutive thresholds
477
+ for i in range(len(area_rule) - 1):
478
+ bbox_area_nums[i] = np.logical_and(
479
+ bbox_area >= area_rule[i]**2,
480
+ bbox_area < area_rule[i + 1]**2).sum()
481
+ elif len(bbox) == 0:
482
+ class_bbox_w.append([0])
483
+ class_bbox_h.append([0])
484
+ class_bbox_ratio.append([0])
485
+
486
+ class_name.append(classes)
487
+ bbox_area_num.append(bbox_area_nums.tolist())
488
+ progress_bar_classes.update()
489
+
490
+ # 3.draw Dataset Information
491
+ if args.func is None:
492
+ show_bbox_num(cfg, args.out_dir, fig_set, class_name, class_num)
493
+ show_bbox_wh(args.out_dir, fig_set, class_bbox_w, class_bbox_h,
494
+ class_name)
495
+ show_bbox_wh_ratio(args.out_dir, fig_set, class_name, class_bbox_ratio)
496
+ show_bbox_area(args.out_dir, fig_set, area_rule, class_name,
497
+ bbox_area_num)
498
+ elif args.func == 'show_bbox_num':
499
+ show_bbox_num(cfg, args.out_dir, fig_set, class_name, class_num)
500
+ print('num_instances_info:')
501
+ print('max num_instances=', max(instance_num))
502
+ print('min num_instances=', min(instance_num))
503
+ print('mean num_instances=', np.mean(instance_num))
504
+ elif args.func == 'show_bbox_wh':
505
+ show_bbox_wh(args.out_dir, fig_set, class_bbox_w, class_bbox_h,
506
+ class_name)
507
+ elif args.func == 'show_bbox_wh_ratio':
508
+ show_bbox_wh_ratio(args.out_dir, fig_set, class_name, class_bbox_ratio)
509
+ elif args.func == 'show_bbox_area':
510
+ show_bbox_area(args.out_dir, fig_set, area_rule, class_name,
511
+ bbox_area_num)
512
+ else:
513
+ raise RuntimeError(
514
+ 'Please enter the correct func name, e.g., show_bbox_num')
515
+
516
+
517
+ if __name__ == '__main__':
518
+ main()
tools/ins_seg/dataset_converters/cityscapes.py ADDED
@@ -0,0 +1,152 @@
1
+ import argparse
2
+ import glob
3
+ import os.path as osp
4
+
5
+ import cityscapesscripts.helpers.labels as CSLabels
6
+ import mmcv
7
+ import numpy as np
8
+ import pycocotools.mask as maskUtils
9
+ from mmengine.fileio import dump
10
+ from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
11
+ track_progress)
12
+
13
+
14
+ def collect_files(img_dir, gt_dir):
15
+ suffix = 'leftImg8bit.png'
16
+ files = []
17
+ for img_file in glob.glob(osp.join(img_dir, '**/*.png')):
18
+ assert img_file.endswith(suffix), img_file
19
+ inst_file = gt_dir + img_file[
20
+ len(img_dir):-len(suffix)] + 'gtFine_instanceIds.png'
21
+ # Note that labelIds are not converted to trainId for seg map
22
+ segm_file = gt_dir + img_file[
23
+ len(img_dir):-len(suffix)] + 'gtFine_labelIds.png'
24
+ files.append((img_file, inst_file, segm_file))
25
+ assert len(files), f'No images found in {img_dir}'
26
+ print(f'Loaded {len(files)} images from {img_dir}')
27
+
28
+ return files
29
+
30
+
31
+ def collect_annotations(files, nproc=1):
32
+ print('Loading annotation images')
33
+ if nproc > 1:
34
+ images = track_parallel_progress(load_img_info, files, nproc=nproc)
35
+ else:
36
+ images = track_progress(load_img_info, files)
37
+
38
+ return images
39
+
40
+
41
+ def load_img_info(files):
42
+ img_file, inst_file, segm_file = files
43
+ inst_img = mmcv.imread(inst_file, 'unchanged')
44
+ # ids < 24 are stuff labels (filtering them first is about 5% faster)
45
+ unique_inst_ids = np.unique(inst_img[inst_img >= 24])
46
+ anno_info = []
47
+ for inst_id in unique_inst_ids:
48
+ # For non-crowd annotations, inst_id // 1000 is the label_id
49
+ # Crowd annotations have <1000 instance ids
50
+ label_id = inst_id // 1000 if inst_id >= 1000 else inst_id
51
+ label = CSLabels.id2label[label_id]
52
+ if not label.hasInstances or label.ignoreInEval:
53
+ continue
54
+
55
+ category_id = label.id
56
+ iscrowd = int(inst_id < 1000)
57
+ mask = np.asarray(inst_img == inst_id, dtype=np.uint8, order='F')
58
+ mask_rle = maskUtils.encode(mask[:, :, None])[0]
59
+
60
+ area = maskUtils.area(mask_rle)
61
+ # convert to COCO style XYWH format
62
+ bbox = maskUtils.toBbox(mask_rle)
63
+
64
+ # for json encoding
65
+ mask_rle['counts'] = mask_rle['counts'].decode()
66
+
67
+ anno = dict(
68
+ iscrowd=iscrowd,
69
+ category_id=category_id,
70
+ bbox=bbox.tolist(),
71
+ area=area.tolist(),
72
+ segmentation=mask_rle)
73
+ anno_info.append(anno)
74
+ video_name = osp.basename(osp.dirname(img_file))
75
+ img_info = dict(
76
+ # remove img_prefix for filename
77
+ file_name=osp.join(video_name, osp.basename(img_file)),
78
+ height=inst_img.shape[0],
79
+ width=inst_img.shape[1],
80
+ anno_info=anno_info,
81
+ segm_file=osp.join(video_name, osp.basename(segm_file)))
82
+
83
+ return img_info
84
+
85
+
86
+ def cvt_annotations(image_infos, out_json_name):
87
+ out_json = dict()
88
+ img_id = 0
89
+ ann_id = 0
90
+ out_json['images'] = []
91
+ out_json['categories'] = []
92
+ out_json['annotations'] = []
93
+ for image_info in image_infos:
94
+ image_info['id'] = img_id
95
+ anno_infos = image_info.pop('anno_info')
96
+ out_json['images'].append(image_info)
97
+ for anno_info in anno_infos:
98
+ anno_info['image_id'] = img_id
99
+ anno_info['id'] = ann_id
100
+ out_json['annotations'].append(anno_info)
101
+ ann_id += 1
102
+ img_id += 1
103
+ for label in CSLabels.labels:
104
+ if label.hasInstances and not label.ignoreInEval:
105
+ cat = dict(id=label.id, name=label.name)
106
+ out_json['categories'].append(cat)
107
+
108
+ if len(out_json['annotations']) == 0:
109
+ out_json.pop('annotations')
110
+
111
+ dump(out_json, out_json_name)
112
+ return out_json
113
+
114
+
115
+ def parse_args():
116
+ parser = argparse.ArgumentParser(
117
+ description='Convert Cityscapes annotations to COCO format')
118
+ parser.add_argument('cityscapes_path', help='cityscapes data path')
119
+ parser.add_argument('--img-dir', default='leftImg8bit', type=str)
120
+ parser.add_argument('--gt-dir', default='gtFine', type=str)
121
+ parser.add_argument('-o', '--out-dir', help='output path')
122
+ parser.add_argument(
123
+ '--nproc', default=1, type=int, help='number of process')
124
+ args = parser.parse_args()
125
+ return args
126
+
127
+
128
+ def main():
129
+ args = parse_args()
130
+ cityscapes_path = args.cityscapes_path
131
+ out_dir = args.out_dir if args.out_dir else cityscapes_path
132
+ mkdir_or_exist(out_dir)
133
+
134
+ img_dir = osp.join(cityscapes_path, args.img_dir)
135
+ gt_dir = osp.join(cityscapes_path, args.gt_dir)
136
+
137
+ set_name = dict(
138
+ train='instancesonly_filtered_gtFine_train.json',
139
+ val='instancesonly_filtered_gtFine_val.json',
140
+ test='instancesonly_filtered_gtFine_test.json')
141
+
142
+ for split, json_name in set_name.items():
143
+ print(f'Converting {split} into {json_name}')
144
+ with Timer(print_tmpl='It took {}s to convert Cityscapes annotation'):
145
+ files = collect_files(
146
+ osp.join(img_dir, split), osp.join(gt_dir, split))
147
+ image_infos = collect_annotations(files, nproc=args.nproc)
148
+ cvt_annotations(image_infos, osp.join(out_dir, json_name))
149
+
150
+
151
+ if __name__ == '__main__':
152
+ main()
tools/ins_seg/dataset_converters/whu_building_convert.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import os.path as osp
5
+
6
+ import cv2
7
+ import mmcv
8
+ import numpy as np
9
+ import pycocotools.mask as maskUtils
10
+ from mmengine.fileio import dump
11
+ from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
12
+ track_progress)
13
+
14
+
15
+ def collect_files(img_dir, gt_dir):
16
+ files = []
17
+ img_files = glob.glob(osp.join(img_dir, 'image/*.tif'))
18
+ for img_file in img_files:
19
+ segm_file = gt_dir + '/label/' + os.path.basename(img_file)
20
+ files.append((img_file, segm_file))
21
+ assert len(files), f'No images found in {img_dir}'
22
+ print(f'Loaded {len(files)} images from {img_dir}')
23
+
24
+ return files
25
+
26
+
27
+ def collect_annotations(files, nproc=1):
28
+ print('Loading annotation images')
29
+ if nproc > 1:
30
+ images = track_parallel_progress(load_img_info, files, nproc=nproc)
31
+ else:
32
+ images = track_progress(load_img_info, files)
33
+
34
+ return images
35
+
36
+
37
+ def load_img_info(files):
38
+ img_file, segm_file = files
39
+ segm_img = mmcv.imread(segm_file, flag='unchanged', backend='cv2')
40
+
41
+ num_labels, instances, stats, centroids = cv2.connectedComponentsWithStats(segm_img, connectivity=4)
42
+
43
+ anno_info = []
44
+ for inst_id in range(1, num_labels):
45
+ category_id = 1
46
+ mask = np.asarray(instances == inst_id, dtype=np.uint8, order='F')
47
+ if mask.max() < 1:
48
+ print(f'Ignore empty instance: {inst_id} in {segm_file}')
49
+ continue
50
+ mask_rle = maskUtils.encode(mask[:, :, None])[0]
51
+ area = maskUtils.area(mask_rle)
52
+ # convert to COCO style XYWH format
53
+ bbox = maskUtils.toBbox(mask_rle)
54
+
55
+ # for json encoding
56
+ mask_rle['counts'] = mask_rle['counts'].decode()
57
+
58
+ anno = dict(
59
+ iscrowd=0,
60
+ category_id=category_id,
61
+ bbox=bbox.tolist(),
62
+ area=area.tolist(),
63
+ segmentation=mask_rle)
64
+ anno_info.append(anno)
65
+ video_name = osp.basename(osp.dirname(img_file))
66
+ img_info = dict(
67
+ # remove img_prefix for filename
68
+ file_name=osp.basename(img_file),
69
+ height=segm_img.shape[0],
70
+ width=segm_img.shape[1],
71
+ anno_info=anno_info,
72
+ segm_file=osp.basename(segm_file))
73
+
74
+ return img_info
75
+
76
+
77
+ def cvt_annotations(image_infos, out_json_name):
78
+ out_json = dict()
79
+ img_id = 0
80
+ ann_id = 0
81
+ out_json['images'] = []
82
+ out_json['categories'] = []
83
+ out_json['annotations'] = []
84
+ for image_info in image_infos:
85
+ image_info['id'] = img_id
86
+ anno_infos = image_info.pop('anno_info')
87
+ out_json['images'].append(image_info)
88
+ for anno_info in anno_infos:
89
+ anno_info['image_id'] = img_id
90
+ anno_info['id'] = ann_id
91
+ out_json['annotations'].append(anno_info)
92
+ ann_id += 1
93
+ img_id += 1
94
+
95
+ cat = dict(id=1, name='building')
96
+ out_json['categories'].append(cat)
97
+
98
+ if len(out_json['annotations']) == 0:
99
+ out_json.pop('annotations')
100
+
101
+ dump(out_json, out_json_name)
102
+ return out_json
103
+
104
+
105
+ def parse_args():
106
+ parser = argparse.ArgumentParser(
107
+ description='Convert WHU Building annotations to COCO format')
108
+ parser.add_argument('--cityscapes_path', default='/Users/kyanchen/datasets/Building/WHU', help='cityscapes data path')
109
+ parser.add_argument('--img-dir', default='', type=str)
110
+ parser.add_argument('--gt-dir', default='', type=str)
111
+ parser.add_argument('-o', '--out-dir', default='/Users/kyanchen/datasets/Building/WHU/annotations', help='output path')
112
+ parser.add_argument(
113
+ '--nproc', default=0, type=int, help='number of process')
114
+ args = parser.parse_args()
115
+ return args
116
+
117
+
118
+ def main():
119
+ args = parse_args()
120
+ cityscapes_path = args.cityscapes_path
121
+ out_dir = args.out_dir if args.out_dir else cityscapes_path
122
+ mkdir_or_exist(out_dir)
123
+
124
+ img_dir = osp.join(cityscapes_path, args.img_dir)
125
+ gt_dir = osp.join(cityscapes_path, args.gt_dir)
126
+
127
+ set_name = dict(
128
+ train='WHU_building_train.json',
129
+ val='WHU_building_val.json',
130
+ test='WHU_building_test.json'
131
+ )
132
+
133
+ for split, json_name in set_name.items():
134
+ print(f'Converting {split} into {json_name}')
135
+ with Timer(print_tmpl='It took {}s to convert Cityscapes annotation'):
136
+ files = collect_files(
137
+ osp.join(img_dir, split), osp.join(gt_dir, split))
138
+ image_infos = collect_annotations(files, nproc=args.nproc)
139
+ cvt_annotations(image_infos, osp.join(out_dir, json_name))
140
+
141
+
142
+ if __name__ == '__main__':
143
+ main()
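The converter above splits each binary WHU label tile into building instances via 4-connected components and writes them out as COCO-style records. A minimal sketch of how the resulting JSON could be sanity-checked with pycocotools; the annotation path below is an assumption based on this script's defaults, not something shipped in this commit.

```python
# Sketch: sanity-check the COCO file produced by whu_building_convert.py.
# The path is hypothetical; adjust it to wherever --out-dir pointed.
from pycocotools.coco import COCO

coco = COCO('/Users/kyanchen/datasets/Building/WHU/annotations/WHU_building_train.json')
img_ids = coco.getImgIds()
print(f'{len(img_ids)} images, {len(coco.getAnnIds())} building instances')

# Inspect one image record and its first instance annotation.
img_info = coco.loadImgs(img_ids[0])[0]
anns = coco.loadAnns(coco.getAnnIds(imgIds=img_info['id']))
print(img_info['file_name'], img_info['height'], img_info['width'])
print(anns[0]['bbox'], anns[0]['area'], anns[0]['iscrowd'])
```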
tools/ins_seg/sam/sam_cls/get_sam_cls_crops.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ import cv2
+ import sys
+ sys.path.append(sys.path[0] + '/../../../..')
+ import torch
+ from mmengine import Config, ProgressBar
+ from torchvision.transforms import InterpolationMode
+ from mmpl.registry import DATASETS
+ from tools.ins_seg.sam.sam_cls.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
+ import torch
+ from mmdet.evaluation.functional import bbox_overlaps
+
+
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ BICUBIC = InterpolationMode.BICUBIC
+
+ data_set_name = 'whu'
+
+ config_file = f'configs/ins_seg/samcls_{data_set_name}_config.py'
+ # sam_checkpoint = "pretrain/sam/sam_vit_b_01ec64.pth"
+ # model_type = "vit_b"
+ sam_checkpoint = "pretrain/sam/sam_vit_h_4b8939.pth"
+ model_type = "vit_h"
+ phase = 'val'
+
+ cache_data_root = f'/data/kyanchen/cache_data/ins_seg/sam_cls/{data_set_name}'
+ cache_data_root = os.path.join(cache_data_root, phase)
+ if not os.path.exists(cache_data_root):
+     os.makedirs(cache_data_root)
+
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ mask_generator = SamAutomaticMaskGenerator(
+     sam,
+     pred_iou_thresh=0.5,
+     box_nms_thresh=0.5,
+     stability_score_thresh=0.6,
+     min_mask_region_area=16,
+     crop_nms_thresh=0.6,
+ )
+
+ cfg = Config.fromfile(config_file)
+ dataset_cfg = cfg.get('datamodule_cfg')
+ dataset_cfg = dataset_cfg.get(f'{phase}_loader').dataset
+ dataset = DATASETS.build(dataset_cfg)
+ class_names = dataset.METAINFO['classes']
+
+ progress_bar = ProgressBar(len(dataset))
+ expand_ratio = 2
+ iou_thresh = 0.2
+ # 1741 2700 3700
+ for index in list(range(len(dataset)))[:500]:
+     print(index)
+     x = dataset[index]
+     img_file = x['data_samples'].img_path
+     gt_bbox = x['data_samples'].gt_instances.bboxes
+     labels = x['data_samples'].gt_instances.labels
+     gt_bbox = gt_bbox.tensor
+     # image = cv2.imread(img_file)
+     # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = x['inputs'].permute(1, 2, 0).numpy()[..., ::-1]
+     masks = mask_generator.generate(image)
+     pred_boxes = [mask['bbox'] for mask in masks]
+     pred_boxes = torch.tensor(pred_boxes, dtype=torch.float32, device=device)
+     pred_boxes[:, 2:] += pred_boxes[:, :2]
+     # # debug to show the image
+     # img = image.copy().astype('uint8')
+     # for gt_box in gt_bbox:
+     #     gt_box = gt_box.cpu().numpy().astype(int)
+     #     cv2.rectangle(img, (gt_box[0], gt_box[1]), (gt_box[2], gt_box[3]), (0, 255, 0), 2)
+     # for gt_box in pred_boxes:
+     #     gt_box = gt_box.cpu().numpy().astype(int)
+     #     cv2.rectangle(img, (gt_box[0], gt_box[1]), (gt_box[2], gt_box[3]), (0, 0, 255), 2)
+     # cv2.imshow("image", img.astype('uint8'))
+     # cv2.waitKey(0)
+
+     ious = bbox_overlaps(gt_bbox.cpu().numpy(), pred_boxes.cpu().numpy())
+     idxs = ious.argmax(axis=0)
+     ious = ious[idxs, range(ious.shape[1])]
+     ious_mask = ious > iou_thresh
+
+     for idx, mask in enumerate(masks):
+         # expand box
+         x, y, w, h = mask['bbox']
+         x = x + w // 2
+         y = y + h // 2
+         w = int(w * expand_ratio)
+         h = int(h * expand_ratio)
+         l = int(x - w // 2)
+         t = int(y - h // 2)
+         r = int(x + w // 2)
+         b = int(y + h // 2)
+         l = max(0, l)
+         t = max(0, t)
+         r = min(image.shape[1], r)
+         b = min(image.shape[0], b)
+         if r - l < 16 or b - t < 16:
+             continue
+
+         # blur image
+         blur_image = image.copy()
+         blur_image = cv2.blur(blur_image, (7, 7))
+         blur_image[mask['segmentation']] = image[mask['segmentation']]
+         crop_image = blur_image[t:b, l:r]
+
+         # # # debug to show the image
+         # cv2.imshow("crop_image", crop_image)
+         # seg_mask = image.copy()
+         # seg_mask[mask['segmentation']] = (0, 0, 255)
+         # seg_mask[mask['segmentation']] = 0.5 * seg_mask[mask['segmentation']] + 0.5 * image[mask['segmentation']]
+         # cv2.imshow("image", seg_mask.astype('uint8'))
+         # cv2.waitKey(0)
+
+         label = 255
+         if ious_mask[idx]:
+             label = labels[idxs[idx]].item()
+         cv2.imwrite(os.path.join(cache_data_root, f"{index}_{idx}_crop_{label}.jpg"), crop_image)
+
+     progress_bar.update()
+
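Each cached crop above carries its label in the file name ({index}_{idx}_crop_{label}.jpg, where 255 marks SAM proposals that matched no ground-truth box above iou_thresh). A hedged sketch of how those names could be parsed back into (path, label) pairs for the classifier stage; the cache directory is the one hard-coded in this script.

```python
# Sketch: rebuild (crop path, label) pairs from the cache written above.
# Relies only on the f"{index}_{idx}_crop_{label}.jpg" naming convention.
import glob
import os


def load_crop_labels(cache_dir):
    samples = []
    for path in glob.glob(os.path.join(cache_dir, '*_crop_*.jpg')):
        label = int(os.path.splitext(path)[0].rsplit('_', 1)[-1])
        samples.append((path, label))
    return samples


samples = load_crop_labels('/data/kyanchen/cache_data/ins_seg/sam_cls/whu/val')
matched = [s for s in samples if s[1] != 255]
print(f'{len(samples)} crops, {len(matched)} matched to a ground-truth label')
```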
tools/ins_seg/sam/sam_cls/get_sam_cls_metrics.py ADDED
@@ -0,0 +1,170 @@
1
+ import sys
2
+
3
+ import cv2
4
+ sys.path.append(sys.path[0] + '/../../../..')
5
+ from mmengine import Config, ProgressBar
6
+ from mmengine.dataset import Compose
7
+ from mmengine.structures import InstanceData
8
+ from torchvision.transforms import InterpolationMode
9
+
10
+ from mmpl.registry import DATASETS, MODELS, METRICS
11
+ from tools.ins_seg.sam.sam_cls.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
12
+ import torch
13
+ import mmpl.evaluation
14
+ import torch.nn.functional as F
15
+
16
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17
+
18
+ dataset_name = 'whu'
19
+ seg_model_cfg_file = f'configs/ins_seg/samcls_{dataset_name}_config.py'
20
+ # sam_checkpoint = "pretrain/sam/sam_vit_b_01ec64.pth"
21
+ # model_type = "vit_b"
22
+ sam_checkpoint = "pretrain/sam/sam_vit_h_4b8939.pth"
23
+ model_type = "vit_h"
24
+
25
+
26
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
27
+ sam.to(device=device)
28
+ sam.eval()
29
+ mask_generator = SamAutomaticMaskGenerator(
30
+ sam,
31
+ pred_iou_thresh=0.5,
32
+ box_nms_thresh=0.5,
33
+ stability_score_thresh=0.6,
34
+ min_mask_region_area=16,
35
+ crop_nms_thresh=0.6,
36
+ )
37
+
38
+ # cls model
39
+ cls_model_cfg_file = f'configs/ins_seg/samcls_res18_{dataset_name}_config.py'
40
+ cls_ckpt = 'results/whu_ins/E20230607_0/checkpoints/epoch_epoch=59-map_valmulticlassaccuracy_0=0.9516.ckpt'
41
+ cls_cfg = Config.fromfile(cls_model_cfg_file)
42
+ cls_model_cfg = cls_cfg.get('model_cfg').whole_model
43
+ cls_model = MODELS.build(cls_model_cfg)
44
+ cls_state_dict = torch.load(cls_ckpt, map_location='cpu')['state_dict']
45
+ cls_state_dict = {k.replace('whole_model.', ''): v for k, v in cls_state_dict.items()}
46
+ cls_model.load_state_dict(cls_state_dict, strict=True)
47
+ cls_model.to(device=device)
48
+ cls_model.eval()
49
+
50
+ cls_transform_cfg = cls_cfg.get('datamodule_cfg').val_loader.dataset.pipeline[1:]
51
+ cls_transforms = Compose(cls_transform_cfg)
52
+
53
+ seg_cfg = Config.fromfile(seg_model_cfg_file)
54
+ seg_dataset_cfg = seg_cfg.get('datamodule_cfg').val_loader.dataset
55
+ seg_dataset = DATASETS.build(seg_dataset_cfg)
56
+ val_evaluator_cfg = seg_cfg['evaluator'].val_evaluator
57
+ val_evaluator = METRICS.build(val_evaluator_cfg)
58
+
59
+ val_evaluator.dataset_meta = seg_dataset.metainfo
60
+
61
+
62
+ progress_bar = ProgressBar(len(seg_dataset))
63
+ expand_ratio = 2
64
+ for index in range(len(seg_dataset)):
65
+ seg_data = seg_dataset[index]
66
+ img_file = seg_data['data_samples'].img_path
67
+ image = cv2.imread(img_file)
68
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
69
+ masks = mask_generator.generate(image)
70
+
71
+ refined_masks = []
72
+ for mask in masks:
73
+ bbox = mask['bbox']
+ x, y, w, h = bbox
75
+ if w < 4 or h < 4:
76
+ continue
77
+ refined_masks.append(mask)
78
+ masks = refined_masks
79
+ mask_cls = []
80
+ for mask in masks:
81
+ # expand box
82
+ x, y, w, h = mask['bbox']
83
+ x = x + w // 2
84
+ y = y + h // 2
85
+ w = int(w * expand_ratio)
86
+ h = int(h * expand_ratio)
87
+ l = int(x - w // 2)
88
+ t = int(y - h // 2)
89
+ r = int(x + w // 2)
90
+ b = int(y + h // 2)
91
+ l = max(0, l)
92
+ t = max(0, t)
93
+ r = min(image.shape[1], r)
94
+ b = min(image.shape[0], b)
95
+
96
+ # blur image
97
+ blur_image = image.copy()
98
+ blur_image = cv2.blur(blur_image, (7, 7))
99
+ blur_image[mask['segmentation']] = image[mask['segmentation']]
100
+ crop_image = blur_image[t:b, l:r]
101
+
102
+ # # debug to show the image
103
+ # cv2.imshow("crop_image", crop_image)
104
+ # seg_mask = image.copy()
105
+ # seg_mask[mask['segmentation']] = (0, 0, 255)
106
+ # seg_mask[mask['segmentation']] = 0.5 * seg_mask[mask['segmentation']] + 0.5 * image[mask['segmentation']]
107
+ # cv2.imshow("image", seg_mask.astype('uint8'))
108
+ # cv2.waitKey(0)
109
+ results = {
110
+ 'img': crop_image,
111
+ }
112
+ transform_crop_data = cls_transforms(results)
113
+ transform_crop_data['inputs'] = transform_crop_data['inputs'].unsqueeze(0).to(device)
114
+ transform_crop_data['data_samples'] = [transform_crop_data['data_samples']]
115
+ data = cls_model.data_preprocessor(transform_crop_data, False)
116
+ results = cls_model._run_forward(data, mode='predict')
117
+
118
+ mask_cls.append(results[0].pred_score)
119
+
120
+ mask_pred = torch.stack([torch.from_numpy(mask['segmentation']) for mask in masks], dim=0).to(device=device)
121
+ bbox_pred = [mask['bbox'] for mask in masks]
122
+ bbox_pred = torch.tensor(bbox_pred, dtype=torch.float32, device=device)
123
+ bbox_pred[:, 2:] += bbox_pred[:, :2]
124
+
125
+ mask_cls = torch.stack(mask_cls, dim=0)
126
+ max_per_image = 100
127
+ num_queries = mask_cls.shape[0]
128
+ num_classes = mask_cls.shape[-1] - 1
129
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
130
+ labels = torch.arange(num_classes, device=mask_cls.device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
131
+ try:
132
+ scores_per_image, top_indices = scores.flatten(0, 1).topk(max_per_image, sorted=False)
133
+ except Exception as e:
134
+ print(e)
135
+ continue
136
+
137
+ labels_per_image = labels[top_indices]
138
+ query_indices = top_indices // num_classes
139
+ mask_pred = mask_pred[query_indices]
140
+ bbox_pred = bbox_pred[query_indices]
141
+
142
+ # # extract things
143
+ # is_thing = labels_per_image < self.num_things_classes
144
+ # scores_per_image = scores_per_image[is_thing]
145
+ # labels_per_image = labels_per_image[is_thing]
146
+ # mask_pred = mask_pred[is_thing]
147
+
148
+ mask_pred_binary = (mask_pred > 0).float()
149
+ # mask_scores_per_image = (mask_pred.sigmoid() *
150
+ # mask_pred_binary).flatten(1).sum(1) / (
151
+ # mask_pred_binary.flatten(1).sum(1) + 1e-6)
152
+ # det_scores = scores_per_image * mask_scores_per_image
153
+ det_scores = scores_per_image
154
+ mask_pred_binary = mask_pred_binary.bool()
155
+
156
+ results = InstanceData()
157
+ results.bboxes = bbox_pred
158
+ results.labels = labels_per_image
159
+ results.scores = det_scores
160
+ results.masks = mask_pred_binary
161
+
162
+ data_samples = seg_data['data_samples']
163
+ data_samples.pred_instances = results
164
+
165
+ val_evaluator.update(None, [data_samples])
166
+ progress_bar.update()
167
+
168
+ metrics = val_evaluator.compute()
169
+ print(metrics)
170
+
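The detection step above flattens the per-proposal class scores before taking a top-k, so each flattened index has to be mapped back to a (mask proposal, class) pair; that is what top_indices // num_classes and the repeated labels tensor do. A small self-contained sketch of the same mapping with made-up numbers:

```python
# Sketch: flattened top-k indices are row-major, i.e. index = query * num_classes + class.
import torch

num_queries, num_classes = 3, 2
scores = torch.tensor([[0.1, 0.9],
                       [0.8, 0.2],
                       [0.4, 0.6]])
top_scores, top_indices = scores.flatten(0, 1).topk(4, sorted=False)
query_indices = top_indices // num_classes  # which SAM mask proposal
class_indices = top_indices % num_classes   # which category (same as labels[top_indices] above)
print(list(zip(query_indices.tolist(), class_indices.tolist(), top_scores.tolist())))
```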
tools/ins_seg/sam/sam_cls/segment_anything/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .build_sam import (
+     build_sam,
+     build_sam_vit_h,
+     build_sam_vit_l,
+     build_sam_vit_b,
+     sam_model_registry,
+ )
+ from .predictor import SamPredictor
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
tools/ins_seg/sam/sam_cls/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
+ class SamAutomaticMaskGenerator:
36
+ def __init__(
37
+ self,
38
+ model: Sam,
39
+ points_per_side: Optional[int] = 32,
40
+ points_per_batch: int = 64,
41
+ pred_iou_thresh: float = 0.88,
42
+ stability_score_thresh: float = 0.95,
43
+ stability_score_offset: float = 1.0,
44
+ box_nms_thresh: float = 0.7,
45
+ crop_n_layers: int = 0,
46
+ crop_nms_thresh: float = 0.7,
47
+ crop_overlap_ratio: float = 512 / 1500,
48
+ crop_n_points_downscale_factor: int = 1,
49
+ point_grids: Optional[List[np.ndarray]] = None,
50
+ min_mask_region_area: int = 0,
51
+ output_mode: str = "binary_mask",
52
+ ) -> None:
53
+ """
54
+ Using a SAM model, generates masks for the entire image.
55
+ Generates a grid of point prompts over the image, then filters
56
+ low quality and duplicate masks. The default settings are chosen
57
+ for SAM with a ViT-H backbone.
58
+
59
+ Arguments:
60
+ model (Sam): The SAM model to use for mask prediction.
61
+ points_per_side (int or None): The number of points to be sampled
62
+ along one side of the image. The total number of points is
63
+ points_per_side**2. If None, 'point_grids' must provide explicit
64
+ point sampling.
65
+ points_per_batch (int): Sets the number of points run simultaneously
66
+ by the model. Higher numbers may be faster but use more GPU memory.
67
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
68
+ model's predicted mask quality.
69
+ stability_score_thresh (float): A filtering threshold in [0,1], using
70
+ the stability of the mask under changes to the cutoff used to binarize
71
+ the model's mask predictions.
72
+ stability_score_offset (float): The amount to shift the cutoff when
73
+ calculating the stability score.
74
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
75
+ suppression to filter duplicate masks.
76
+ crop_n_layers (int): If >0, mask prediction will be run again on
77
+ crops of the image. Sets the number of layers to run, where each
78
+ layer has 2**i_layer number of image crops.
79
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks between different crops.
81
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
82
+ In the first crop layer, crops will overlap by this fraction of
83
+ the image length. Later layers with more crops scale down this overlap.
84
+ crop_n_points_downscale_factor (int): The number of points-per-side
85
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86
+ point_grids (list(np.ndarray) or None): A list over explicit grids
87
+ of points used for sampling, normalized to [0,1]. The nth grid in the
88
+ list is used in the nth crop layer. Exclusive with points_per_side.
89
+ min_mask_region_area (int): If >0, postprocessing will be applied
90
+ to remove disconnected regions and holes in masks with area smaller
91
+ than min_mask_region_area. Requires opencv.
92
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
93
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94
+ For large resolutions, 'binary_mask' may consume large amounts of
95
+ memory.
96
+ """
97
+
98
+ assert (points_per_side is None) != (
99
+ point_grids is None
100
+ ), "Exactly one of points_per_side or point_grid must be provided."
101
+ if points_per_side is not None:
102
+ self.point_grids = build_all_layer_point_grids(
103
+ points_per_side,
104
+ crop_n_layers,
105
+ crop_n_points_downscale_factor,
106
+ )
107
+ elif point_grids is not None:
108
+ self.point_grids = point_grids
109
+ else:
110
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
111
+
112
+ assert output_mode in [
113
+ "binary_mask",
114
+ "uncompressed_rle",
115
+ "coco_rle",
116
+ ], f"Unknown output_mode {output_mode}."
117
+ if output_mode == "coco_rle":
118
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119
+
120
+ if min_mask_region_area > 0:
121
+ import cv2 # type: ignore # noqa: F401
122
+
123
+ self.predictor = SamPredictor(model)
124
+ self.points_per_batch = points_per_batch
125
+ self.pred_iou_thresh = pred_iou_thresh
126
+ self.stability_score_thresh = stability_score_thresh
127
+ self.stability_score_offset = stability_score_offset
128
+ self.box_nms_thresh = box_nms_thresh
129
+ self.crop_n_layers = crop_n_layers
130
+ self.crop_nms_thresh = crop_nms_thresh
131
+ self.crop_overlap_ratio = crop_overlap_ratio
132
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133
+ self.min_mask_region_area = min_mask_region_area
134
+ self.output_mode = output_mode
135
+
136
+ @torch.no_grad()
137
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138
+ """
139
+ Generates masks for the given image.
140
+
141
+ Arguments:
142
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143
+
144
+ Returns:
145
+ list(dict(str, any)): A list over records for masks. Each record is
146
+ a dict containing the following keys:
147
+ segmentation (dict(str, any) or np.ndarray): The mask. If
148
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
149
+ is a dictionary containing the RLE.
150
+ bbox (list(float)): The box around the mask, in XYWH format.
151
+ area (int): The area in pixels of the mask.
152
+ predicted_iou (float): The model's own prediction of the mask's
153
+ quality. This is filtered by the pred_iou_thresh parameter.
154
+ point_coords (list(list(float))): The point coordinates input
155
+ to the model to generate this mask.
156
+ stability_score (float): A measure of the mask's quality. This
157
+ is filtered on using the stability_score_thresh parameter.
158
+ crop_box (list(float)): The crop of the image used to generate
159
+ the mask, given in XYWH format.
160
+ """
161
+
162
+ # Generate masks
163
+ mask_data = self._generate_masks(image)
164
+
165
+ # Filter small disconnected regions and holes in masks
166
+ if self.min_mask_region_area > 0:
167
+ mask_data = self.postprocess_small_regions(
168
+ mask_data,
169
+ self.min_mask_region_area,
170
+ max(self.box_nms_thresh, self.crop_nms_thresh),
171
+ )
172
+
173
+ # Encode masks
174
+ if self.output_mode == "coco_rle":
175
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176
+ elif self.output_mode == "binary_mask":
177
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
+ else:
179
+ mask_data["segmentations"] = mask_data["rles"]
180
+
181
+ # Write mask records
182
+ curr_anns = []
183
+ for idx in range(len(mask_data["segmentations"])):
184
+ ann = {
185
+ "segmentation": mask_data["segmentations"][idx],
186
+ "area": area_from_rle(mask_data["rles"][idx]),
187
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
189
+ "point_coords": [mask_data["points"][idx].tolist()],
190
+ "stability_score": mask_data["stability_score"][idx].item(),
191
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192
+ }
193
+ curr_anns.append(ann)
194
+
195
+ return curr_anns
196
+
197
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
198
+ orig_size = image.shape[:2]
199
+ crop_boxes, layer_idxs = generate_crop_boxes(
200
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
201
+ )
202
+
203
+ # Iterate over image crops
204
+ data = MaskData()
205
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207
+ data.cat(crop_data)
208
+
209
+ # Remove duplicate masks between crops
210
+ if len(crop_boxes) > 1:
211
+ # Prefer masks from smaller crops
212
+ scores = 1 / box_area(data["crop_boxes"])
213
+ scores = scores.to(data["boxes"].device)
214
+ keep_by_nms = batched_nms(
215
+ data["boxes"].float(),
216
+ scores,
217
+ torch.zeros_like(data["boxes"][:, 0]), # categories
218
+ iou_threshold=self.crop_nms_thresh,
219
+ )
220
+ data.filter(keep_by_nms)
221
+
222
+ data.to_numpy()
223
+ return data
224
+
225
+ def _process_crop(
226
+ self,
227
+ image: np.ndarray,
228
+ crop_box: List[int],
229
+ crop_layer_idx: int,
230
+ orig_size: Tuple[int, ...],
231
+ ) -> MaskData:
232
+ # Crop the image and calculate embeddings
233
+ x0, y0, x1, y1 = crop_box
234
+ cropped_im = image[y0:y1, x0:x1, :]
235
+ cropped_im_size = cropped_im.shape[:2]
236
+ self.predictor.set_image(cropped_im)
237
+
238
+ # Get points for this crop
239
+ points_scale = np.array(cropped_im_size)[None, ::-1]
240
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
241
+
242
+ # Generate masks for this crop in batches
243
+ data = MaskData()
244
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246
+ data.cat(batch_data)
247
+ del batch_data
248
+ self.predictor.reset_image()
249
+
250
+ # Remove duplicates within this crop.
251
+ keep_by_nms = batched_nms(
252
+ data["boxes"].float(),
253
+ data["iou_preds"],
254
+ torch.zeros_like(data["boxes"][:, 0]), # categories
255
+ iou_threshold=self.box_nms_thresh,
256
+ )
257
+ data.filter(keep_by_nms)
258
+
259
+ # Return to the original image frame
260
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261
+ data["points"] = uncrop_points(data["points"], crop_box)
262
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263
+
264
+ return data
265
+
266
+ def _process_batch(
267
+ self,
268
+ points: np.ndarray,
269
+ im_size: Tuple[int, ...],
270
+ crop_box: List[int],
271
+ orig_size: Tuple[int, ...],
272
+ ) -> MaskData:
273
+ orig_h, orig_w = orig_size
274
+
275
+ # Run model on this batch
276
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279
+ masks, iou_preds, _ = self.predictor.predict_torch(
280
+ in_points[:, None, :],
281
+ in_labels[:, None],
282
+ multimask_output=True,
283
+ return_logits=True,
284
+ )
285
+
286
+ # Serialize predictions and store in MaskData
287
+ data = MaskData(
288
+ masks=masks.flatten(0, 1),
289
+ iou_preds=iou_preds.flatten(0, 1),
290
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291
+ )
292
+ del masks
293
+
294
+ # Filter by predicted IoU
295
+ if self.pred_iou_thresh > 0.0:
296
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
297
+ data.filter(keep_mask)
298
+
299
+ # Calculate stability score
300
+ data["stability_score"] = calculate_stability_score(
301
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302
+ )
303
+ if self.stability_score_thresh > 0.0:
304
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
305
+ data.filter(keep_mask)
306
+
307
+ # Threshold masks and calculate boxes
308
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309
+ data["boxes"] = batched_mask_to_box(data["masks"])
310
+
311
+ # Filter boxes that touch crop boundaries
312
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
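A minimal usage sketch of the generator defined above, mirroring how get_sam_cls_crops.py drives it; the checkpoint path and test image are taken from elsewhere in this commit and assumed to exist locally.

```python
# Sketch: generate class-agnostic mask proposals for a single RGB image.
import cv2
import torch
from tools.ins_seg.sam.sam_cls.segment_anything import (SamAutomaticMaskGenerator,
                                                         sam_model_registry)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry['vit_h'](checkpoint='pretrain/sam/sam_vit_h_4b8939.pth')
sam.to(device=device)
generator = SamAutomaticMaskGenerator(sam, pred_iou_thresh=0.5, box_nms_thresh=0.5)

image = cv2.cvtColor(cv2.imread('visualizer/test_img.jpg'), cv2.COLOR_BGR2RGB)
masks = generator.generate(image)  # list of dicts, see the docstring of generate()
for m in masks[:3]:
    print(m['bbox'], m['area'], m['predicted_iou'], m['stability_score'])
```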
tools/ins_seg/sam/sam_cls/segment_anything/build_sam.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from functools import partial
+
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+ def build_sam_vit_h(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=1280,
+         encoder_depth=32,
+         encoder_num_heads=16,
+         encoder_global_attn_indexes=[7, 15, 23, 31],
+         checkpoint=checkpoint,
+     )
+
+
+ build_sam = build_sam_vit_h
+
+
+ def build_sam_vit_l(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=1024,
+         encoder_depth=24,
+         encoder_num_heads=16,
+         encoder_global_attn_indexes=[5, 11, 17, 23],
+         checkpoint=checkpoint,
+     )
+
+
+ def build_sam_vit_b(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=768,
+         encoder_depth=12,
+         encoder_num_heads=12,
+         encoder_global_attn_indexes=[2, 5, 8, 11],
+         checkpoint=checkpoint,
+     )
+
+
+ sam_model_registry = {
+     "default": build_sam_vit_h,
+     "vit_h": build_sam_vit_h,
+     "vit_l": build_sam_vit_l,
+     "vit_b": build_sam_vit_b,
+ }
+
+
+ def _build_sam(
+     encoder_embed_dim,
+     encoder_depth,
+     encoder_num_heads,
+     encoder_global_attn_indexes,
+     checkpoint=None,
+ ):
+     prompt_embed_dim = 256
+     image_size = 1024
+     vit_patch_size = 16
+     image_embedding_size = image_size // vit_patch_size
+     sam = Sam(
+         image_encoder=ImageEncoderViT(
+             depth=encoder_depth,
+             embed_dim=encoder_embed_dim,
+             img_size=image_size,
+             mlp_ratio=4,
+             norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+             num_heads=encoder_num_heads,
+             patch_size=vit_patch_size,
+             qkv_bias=True,
+             use_rel_pos=True,
+             global_attn_indexes=encoder_global_attn_indexes,
+             window_size=14,
+             out_chans=prompt_embed_dim,
+         ),
+         prompt_encoder=PromptEncoder(
+             embed_dim=prompt_embed_dim,
+             image_embedding_size=(image_embedding_size, image_embedding_size),
+             input_image_size=(image_size, image_size),
+             mask_in_chans=16,
+         ),
+         mask_decoder=MaskDecoder(
+             num_multimask_outputs=3,
+             transformer=TwoWayTransformer(
+                 depth=2,
+                 embedding_dim=prompt_embed_dim,
+                 mlp_dim=2048,
+                 num_heads=8,
+             ),
+             transformer_dim=prompt_embed_dim,
+             iou_head_depth=3,
+             iou_head_hidden_dim=256,
+         ),
+         pixel_mean=[123.675, 116.28, 103.53],
+         pixel_std=[58.395, 57.12, 57.375],
+     )
+     sam.eval()
+     if checkpoint is not None:
+         with open(checkpoint, "rb") as f:
+             state_dict = torch.load(f)
+         sam.load_state_dict(state_dict)
+     return sam
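A short sketch of how the registry above is used by the tools scripts to load a checkpoint; the ViT-B checkpoint path is the one commented out in those scripts and is assumed to be downloaded already.

```python
# Sketch: build a ViT-B SAM from the registry and count its parameters.
import torch
from tools.ins_seg.sam.sam_cls.segment_anything import sam_model_registry

sam = sam_model_registry['vit_b'](checkpoint='pretrain/sam/sam_vit_b_01ec64.pth')
sam.to('cuda:0' if torch.cuda.is_available() else 'cpu')
n_params = sum(p.numel() for p in sam.parameters())
print(f'ViT-B SAM: {n_params / 1e6:.1f}M parameters')
```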
tools/ins_seg/sam/sam_cls/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .sam import Sam
+ from .image_encoder import ImageEncoderViT
+ from .mask_decoder import MaskDecoder
+ from .prompt_encoder import PromptEncoder
+ from .transformer import TwoWayTransformer
tools/ins_seg/sam/sam_cls/segment_anything/modeling/common.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+
+ from typing import Type
+
+
+ class MLPBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         mlp_dim: int,
+         act: Type[nn.Module] = nn.GELU,
+     ) -> None:
+         super().__init__()
+         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+         self.act = act()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lin2(self.act(self.lin1(x)))
+
+
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
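LayerNorm2d above normalizes an NCHW tensor over its channel dimension with a biased variance estimate, which matches torch.nn.LayerNorm applied channel-last when the eps values agree. A quick equivalence sketch, assuming the repository root is on PYTHONPATH:

```python
# Sketch: LayerNorm2d(C) on NCHW equals nn.LayerNorm(C) on NHWC at initialization.
import torch
import torch.nn as nn
from tools.ins_seg.sam.sam_cls.segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)
ln2d = LayerNorm2d(8)                        # weight=1, bias=0 at init
ref = nn.LayerNorm(8, eps=1e-6)              # same eps as LayerNorm2d above
out_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(ln2d(x), out_ref, atol=1e-5))  # expected: True
```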
tools/ins_seg/sam/sam_cls/segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+
58
+ self.patch_embed = PatchEmbed(
59
+ kernel_size=(patch_size, patch_size),
60
+ stride=(patch_size, patch_size),
61
+ in_chans=in_chans,
62
+ embed_dim=embed_dim,
63
+ )
64
+
65
+ self.pos_embed: Optional[nn.Parameter] = None
66
+ if use_abs_pos:
67
+ # Initialize absolute positional embedding with pretrain image size.
68
+ self.pos_embed = nn.Parameter(
69
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70
+ )
71
+
72
+ self.blocks = nn.ModuleList()
73
+ for i in range(depth):
74
+ block = Block(
75
+ dim=embed_dim,
76
+ num_heads=num_heads,
77
+ mlp_ratio=mlp_ratio,
78
+ qkv_bias=qkv_bias,
79
+ norm_layer=norm_layer,
80
+ act_layer=act_layer,
81
+ use_rel_pos=use_rel_pos,
82
+ rel_pos_zero_init=rel_pos_zero_init,
83
+ window_size=window_size if i not in global_attn_indexes else 0,
84
+ input_size=(img_size // patch_size, img_size // patch_size),
85
+ )
86
+ self.blocks.append(block)
87
+
88
+ self.neck = nn.Sequential(
89
+ nn.Conv2d(
90
+ embed_dim,
91
+ out_chans,
92
+ kernel_size=1,
93
+ bias=False,
94
+ ),
95
+ LayerNorm2d(out_chans),
96
+ nn.Conv2d(
97
+ out_chans,
98
+ out_chans,
99
+ kernel_size=3,
100
+ padding=1,
101
+ bias=False,
102
+ ),
103
+ LayerNorm2d(out_chans),
104
+ )
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ x = self.patch_embed(x)
108
+ if self.pos_embed is not None:
109
+ x = x + self.pos_embed
110
+
111
+ for blk in self.blocks:
112
+ x = blk(x)
113
+
114
+ x = self.neck(x.permute(0, 3, 1, 2))
115
+
116
+ return x
117
+
118
+
119
+ class Block(nn.Module):
120
+ """Transformer blocks with support of window attention and residual propagation blocks"""
121
+
122
+ def __init__(
123
+ self,
124
+ dim: int,
125
+ num_heads: int,
126
+ mlp_ratio: float = 4.0,
127
+ qkv_bias: bool = True,
128
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
129
+ act_layer: Type[nn.Module] = nn.GELU,
130
+ use_rel_pos: bool = False,
131
+ rel_pos_zero_init: bool = True,
132
+ window_size: int = 0,
133
+ input_size: Optional[Tuple[int, int]] = None,
134
+ ) -> None:
135
+ """
136
+ Args:
137
+ dim (int): Number of input channels.
138
+ num_heads (int): Number of attention heads in each ViT block.
139
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
141
+ norm_layer (nn.Module): Normalization layer.
142
+ act_layer (nn.Module): Activation layer.
143
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145
+ window_size (int): Window size for window attention blocks. If it equals 0, then
146
+ use global attention.
147
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
148
+ positional parameter size.
149
+ """
150
+ super().__init__()
151
+ self.norm1 = norm_layer(dim)
152
+ self.attn = Attention(
153
+ dim,
154
+ num_heads=num_heads,
155
+ qkv_bias=qkv_bias,
156
+ use_rel_pos=use_rel_pos,
157
+ rel_pos_zero_init=rel_pos_zero_init,
158
+ input_size=input_size if window_size == 0 else (window_size, window_size),
159
+ )
160
+
161
+ self.norm2 = norm_layer(dim)
162
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163
+
164
+ self.window_size = window_size
165
+
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
167
+ shortcut = x
168
+ x = self.norm1(x)
169
+ # Window partition
170
+ if self.window_size > 0:
171
+ H, W = x.shape[1], x.shape[2]
172
+ x, pad_hw = window_partition(x, self.window_size)
173
+
174
+ x = self.attn(x)
175
+ # Reverse window partition
176
+ if self.window_size > 0:
177
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178
+
179
+ x = shortcut + x
180
+ x = x + self.mlp(self.norm2(x))
181
+
182
+ return x
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """Multi-head Attention block with relative position embeddings."""
187
+
188
+ def __init__(
189
+ self,
190
+ dim: int,
191
+ num_heads: int = 8,
192
+ qkv_bias: bool = True,
193
+ use_rel_pos: bool = False,
194
+ rel_pos_zero_init: bool = True,
195
+ input_size: Optional[Tuple[int, int]] = None,
196
+ ) -> None:
197
+ """
198
+ Args:
199
+ dim (int): Number of input channels.
200
+ num_heads (int): Number of attention heads.
201
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
202
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
203
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
205
+ positional parameter size.
206
+ """
207
+ super().__init__()
208
+ self.num_heads = num_heads
209
+ head_dim = dim // num_heads
210
+ self.scale = head_dim**-0.5
211
+
212
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213
+ self.proj = nn.Linear(dim, dim)
214
+
215
+ self.use_rel_pos = use_rel_pos
216
+ if self.use_rel_pos:
217
+ assert (
218
+ input_size is not None
219
+ ), "Input size must be provided if using relative positional encoding."
220
+ # initialize relative positional embeddings
221
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223
+
224
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
225
+ B, H, W, _ = x.shape
226
+ # qkv with shape (3, B, nHead, H * W, C)
227
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228
+ # q, k, v with shape (B * nHead, H * W, C)
229
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230
+
231
+ attn = (q * self.scale) @ k.transpose(-2, -1)
232
+
233
+ if self.use_rel_pos:
234
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235
+
236
+ attn = attn.softmax(dim=-1)
237
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238
+ x = self.proj(x)
239
+
240
+ return x
241
+
242
+
243
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244
+ """
245
+ Partition into non-overlapping windows with padding if needed.
246
+ Args:
247
+ x (tensor): input tokens with [B, H, W, C].
248
+ window_size (int): window size.
249
+
250
+ Returns:
251
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
252
+ (Hp, Wp): padded height and width before partition
253
+ """
254
+ B, H, W, C = x.shape
255
+
256
+ pad_h = (window_size - H % window_size) % window_size
257
+ pad_w = (window_size - W % window_size) % window_size
258
+ if pad_h > 0 or pad_w > 0:
259
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260
+ Hp, Wp = H + pad_h, W + pad_w
261
+
262
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264
+ return windows, (Hp, Wp)
265
+
266
+
267
+ def window_unpartition(
268
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269
+ ) -> torch.Tensor:
270
+ """
271
+ Window unpartition into original sequences and removing padding.
272
+ Args:
273
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274
+ window_size (int): window size.
275
+ pad_hw (Tuple): padded height and width (Hp, Wp).
276
+ hw (Tuple): original height and width (H, W) before padding.
277
+
278
+ Returns:
279
+ x: unpartitioned sequences with [B, H, W, C].
280
+ """
281
+ Hp, Wp = pad_hw
282
+ H, W = hw
283
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286
+
287
+ if Hp > H or Wp > W:
288
+ x = x[:, :H, :W, :].contiguous()
289
+ return x
290
+
291
+
292
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293
+ """
294
+ Get relative positional embeddings according to the relative positions of
295
+ query and key sizes.
296
+ Args:
297
+ q_size (int): size of query q.
298
+ k_size (int): size of key k.
299
+ rel_pos (Tensor): relative position embeddings (L, C).
300
+
301
+ Returns:
302
+ Extracted positional embeddings according to relative positions.
303
+ """
304
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
305
+ # Interpolate rel pos if needed.
306
+ if rel_pos.shape[0] != max_rel_dist:
307
+ # Interpolate rel pos.
308
+ rel_pos_resized = F.interpolate(
309
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310
+ size=max_rel_dist,
311
+ mode="linear",
312
+ )
313
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314
+ else:
315
+ rel_pos_resized = rel_pos
316
+
317
+ # Scale the coords with short length if shapes for q and k are different.
318
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321
+
322
+ return rel_pos_resized[relative_coords.long()]
323
+
324
+
325
+ def add_decomposed_rel_pos(
326
+ attn: torch.Tensor,
327
+ q: torch.Tensor,
328
+ rel_pos_h: torch.Tensor,
329
+ rel_pos_w: torch.Tensor,
330
+ q_size: Tuple[int, int],
331
+ k_size: Tuple[int, int],
332
+ ) -> torch.Tensor:
333
+ """
334
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336
+ Args:
337
+ attn (Tensor): attention map.
338
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343
+
344
+ Returns:
345
+ attn (Tensor): attention map with added relative positional embeddings.
346
+ """
347
+ q_h, q_w = q_size
348
+ k_h, k_w = k_size
349
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351
+
352
+ B, _, dim = q.shape
353
+ r_q = q.reshape(B, q_h, q_w, dim)
354
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356
+
357
+ attn = (
358
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359
+ ).view(B, q_h * q_w, k_h * k_w)
360
+
361
+ return attn
362
+
363
+
364
+ class PatchEmbed(nn.Module):
365
+ """
366
+ Image to Patch Embedding.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ kernel_size: Tuple[int, int] = (16, 16),
372
+ stride: Tuple[int, int] = (16, 16),
373
+ padding: Tuple[int, int] = (0, 0),
374
+ in_chans: int = 3,
375
+ embed_dim: int = 768,
376
+ ) -> None:
377
+ """
378
+ Args:
379
+ kernel_size (Tuple): kernel size of the projection layer.
380
+ stride (Tuple): stride of the projection layer.
381
+ padding (Tuple): padding size of the projection layer.
382
+ in_chans (int): Number of input image channels.
383
+ embed_dim (int): Patch embedding dimension.
384
+ """
385
+ super().__init__()
386
+
387
+ self.proj = nn.Conv2d(
388
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389
+ )
390
+
391
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
392
+ x = self.proj(x)
393
+ # B C H W -> B H W C
394
+ x = x.permute(0, 2, 3, 1)
395
+ return x
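window_partition and window_unpartition above pad the feature map up to a multiple of the window size, split it into windows, and then undo both steps. A small round-trip sketch with concrete shapes, assuming the repository root is on PYTHONPATH:

```python
# Sketch: a 1x10x14x3 map is padded to 14x14, split into one 14x14 window, and restored.
import torch
from tools.ins_seg.sam.sam_cls.segment_anything.modeling.image_encoder import (
    window_partition, window_unpartition)

x = torch.randn(1, 10, 14, 3)            # B, H, W, C (channel-last, as in Block.forward)
windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape, pad_hw)             # torch.Size([1, 14, 14, 3]) (14, 14)
x_back = window_unpartition(windows, 14, pad_hw, (10, 14))
print(torch.equal(x, x_back))            # True once the padding is stripped
```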
tools/ins_seg/sam/sam_cls/segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import List, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ transformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55
+ LayerNorm2d(transformer_dim // 4),
56
+ activation(),
57
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58
+ activation(),
59
+ )
60
+ self.output_hypernetworks_mlps = nn.ModuleList(
61
+ [
62
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63
+ for i in range(self.num_mask_tokens)
64
+ ]
65
+ )
66
+
67
+ self.iou_prediction_head = MLP(
68
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ image_embeddings: torch.Tensor,
74
+ image_pe: torch.Tensor,
75
+ sparse_prompt_embeddings: torch.Tensor,
76
+ dense_prompt_embeddings: torch.Tensor,
77
+ multimask_output: bool,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
79
+ """
80
+ Predict masks given image and prompt embeddings.
81
+
82
+ Arguments:
83
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
84
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87
+ multimask_output (bool): Whether to return multiple masks or a single
88
+ mask.
89
+
90
+ Returns:
91
+ torch.Tensor: batched predicted masks
92
+ torch.Tensor: batched predictions of mask quality
93
+ """
94
+ masks, iou_pred = self.predict_masks(
95
+ image_embeddings=image_embeddings,
96
+ image_pe=image_pe,
97
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
98
+ dense_prompt_embeddings=dense_prompt_embeddings,
99
+ )
100
+
101
+ # Select the correct mask or masks for output
102
+ if multimask_output:
103
+ mask_slice = slice(1, None)
104
+ else:
105
+ mask_slice = slice(0, 1)
106
+ masks = masks[:, mask_slice, :, :]
107
+ iou_pred = iou_pred[:, mask_slice]
108
+
109
+ # Prepare output
110
+ return masks, iou_pred
111
+
112
+ def predict_masks(
113
+ self,
114
+ image_embeddings: torch.Tensor,
115
+ image_pe: torch.Tensor,
116
+ sparse_prompt_embeddings: torch.Tensor,
117
+ dense_prompt_embeddings: torch.Tensor,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ """Predicts masks. See 'forward' for more details."""
120
+ # Concatenate output tokens
121
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124
+
125
+ # Expand per-image data in batch direction to be per-mask
126
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127
+ src = src + dense_prompt_embeddings
128
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129
+ b, c, h, w = src.shape
130
+
131
+ # Run the transformer
132
+ hs, src = self.transformer(src, pos_src, tokens)
133
+ iou_token_out = hs[:, 0, :]
134
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135
+
136
+ # Upscale mask embeddings and predict masks using the mask tokens
137
+ src = src.transpose(1, 2).view(b, c, h, w)
138
+ upscaled_embedding = self.output_upscaling(src)
139
+ hyper_in_list: List[torch.Tensor] = []
140
+ for i in range(self.num_mask_tokens):
141
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142
+ hyper_in = torch.stack(hyper_in_list, dim=1)
143
+ b, c, h, w = upscaled_embedding.shape
144
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145
+
146
+ # Generate mask quality predictions
147
+ iou_pred = self.iou_prediction_head(iou_token_out)
148
+
149
+ return masks, iou_pred
150
+
151
+
152
+ # Lightly adapted from
153
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154
+ class MLP(nn.Module):
155
+ def __init__(
156
+ self,
157
+ input_dim: int,
158
+ hidden_dim: int,
159
+ output_dim: int,
160
+ num_layers: int,
161
+ sigmoid_output: bool = False,
162
+ ) -> None:
163
+ super().__init__()
164
+ self.num_layers = num_layers
165
+ h = [hidden_dim] * (num_layers - 1)
166
+ self.layers = nn.ModuleList(
167
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168
+ )
169
+ self.sigmoid_output = sigmoid_output
170
+
171
+ def forward(self, x):
172
+ for i, layer in enumerate(self.layers):
173
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174
+ if self.sigmoid_output:
175
+ x = F.sigmoid(x)
176
+ return x
tools/ins_seg/sam/sam_cls/segment_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from typing import Any, Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int],
21
+ input_image_size: Tuple[int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (int): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47
+ self.point_embeddings = nn.ModuleList(point_embeddings)
48
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
+
50
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51
+ self.mask_downscaling = nn.Sequential(
52
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53
+ LayerNorm2d(mask_in_chans // 4),
54
+ activation(),
55
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56
+ LayerNorm2d(mask_in_chans),
57
+ activation(),
58
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59
+ )
60
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
61
+
62
+ def get_dense_pe(self) -> torch.Tensor:
63
+ """
64
+ Returns the positional encoding used to encode point prompts,
65
+ applied to a dense set of points the shape of the image encoding.
66
+
67
+ Returns:
68
+ torch.Tensor: Positional encoding with shape
69
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
70
+ """
71
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72
+
73
+ def _embed_points(
74
+ self,
75
+ points: torch.Tensor,
76
+ labels: torch.Tensor,
77
+ pad: bool,
78
+ ) -> torch.Tensor:
79
+ """Embeds point prompts."""
80
+ points = points + 0.5 # Shift to center of pixel
81
+ if pad:
82
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84
+ points = torch.cat([points, padding_point], dim=1)
85
+ labels = torch.cat([labels, padding_label], dim=1)
86
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87
+ point_embedding[labels == -1] = 0.0
88
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
89
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
90
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
91
+ return point_embedding
92
+
93
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94
+ """Embeds box prompts."""
95
+ boxes = boxes + 0.5 # Shift to center of pixel
96
+ coords = boxes.reshape(-1, 2, 2)
97
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100
+ return corner_embedding
101
+
102
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103
+ """Embeds mask inputs."""
104
+ mask_embedding = self.mask_downscaling(masks)
105
+ return mask_embedding
106
+
107
+ def _get_batch_size(
108
+ self,
109
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110
+ boxes: Optional[torch.Tensor],
111
+ masks: Optional[torch.Tensor],
112
+ ) -> int:
113
+ """
114
+ Gets the batch size of the output given the batch size of the input prompts.
115
+ """
116
+ if points is not None:
117
+ return points[0].shape[0]
118
+ elif boxes is not None:
119
+ return boxes.shape[0]
120
+ elif masks is not None:
121
+ return masks.shape[0]
122
+ else:
123
+ return 1
124
+
125
+ def _get_device(self) -> torch.device:
126
+ return self.point_embeddings[0].weight.device
127
+
128
+ def forward(
129
+ self,
130
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131
+ boxes: Optional[torch.Tensor],
132
+ masks: Optional[torch.Tensor],
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ """
135
+ Embeds different types of prompts, returning both sparse and dense
136
+ embeddings.
137
+
138
+ Arguments:
139
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140
+ and labels to embed.
141
+ boxes (torch.Tensor or none): boxes to embed
142
+ masks (torch.Tensor or none): masks to embed
143
+
144
+ Returns:
145
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
146
+ BxNx(embed_dim), where N is determined by the number of input points
147
+ and boxes.
148
+ torch.Tensor: dense embeddings for the masks, in the shape
149
+ Bx(embed_dim)x(embed_H)x(embed_W)
150
+ """
151
+ bs = self._get_batch_size(points, boxes, masks)
152
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153
+ if points is not None:
154
+ coords, labels = points
155
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157
+ if boxes is not None:
158
+ box_embeddings = self._embed_boxes(boxes)
159
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160
+
161
+ if masks is not None:
162
+ dense_embeddings = self._embed_masks(masks)
163
+ else:
164
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166
+ )
167
+
168
+ return sparse_embeddings, dense_embeddings
169
+
170
+
171
+ class PositionEmbeddingRandom(nn.Module):
172
+ """
173
+ Positional encoding using random spatial frequencies.
174
+ """
175
+
176
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177
+ super().__init__()
178
+ if scale is None or scale <= 0.0:
179
+ scale = 1.0
180
+ self.register_buffer(
181
+ "positional_encoding_gaussian_matrix",
182
+ scale * torch.randn((2, num_pos_feats)),
183
+ )
184
+
185
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186
+ """Positionally encode points that are normalized to [0,1]."""
187
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188
+ coords = 2 * coords - 1
189
+ coords = coords @ self.positional_encoding_gaussian_matrix
190
+ coords = 2 * np.pi * coords
191
+ # outputs d_1 x ... x d_n x C shape
192
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193
+
194
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195
+ """Generate positional encoding for a grid of the specified size."""
196
+ h, w = size
197
+ device: Any = self.positional_encoding_gaussian_matrix.device
198
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
199
+ y_embed = grid.cumsum(dim=0) - 0.5
200
+ x_embed = grid.cumsum(dim=1) - 0.5
201
+ y_embed = y_embed / h
202
+ x_embed = x_embed / w
203
+
204
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205
+ return pe.permute(2, 0, 1) # C x H x W
206
+
207
+ def forward_with_coords(
208
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209
+ ) -> torch.Tensor:
210
+ """Positionally encode points that are not normalized to [0,1]."""
211
+ coords = coords_input.clone()
212
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
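As a quick orientation to the PromptEncoder API above, the sketch below embeds a single foreground point; the constructor arguments mirror the usual SAM-base settings and the import path may need adjusting to this repo's vendored package, so treat both as assumptions:

import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder  # adjust to the vendored path if needed

pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                   input_image_size=(1024, 1024), mask_in_chans=16)
points = torch.tensor([[[512.0, 512.0]]])   # B x N x 2, in the padded 1024x1024 input frame
labels = torch.ones(1, 1)                   # 1 = foreground, 0 = background
sparse, dense = pe(points=(points, labels), boxes=None, masks=None)
print(sparse.shape, dense.shape)            # (1, 2, 256) after point padding, (1, 256, 64, 64)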
tools/ins_seg/sam/sam_cls/segment_anything/modeling/sam.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder import MaskDecoder
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoder,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48
+
49
+ @property
50
+ def device(self) -> Any:
51
+ return self.pixel_mean.device
52
+
53
+ @torch.no_grad()
54
+ def forward(
55
+ self,
56
+ batched_input: List[Dict[str, Any]],
57
+ multimask_output: bool,
58
+ ) -> List[Dict[str, torch.Tensor]]:
59
+ """
60
+ Predicts masks end-to-end from provided images and prompts.
61
+ If prompts are not known in advance, using SamPredictor is
62
+ recommended over calling the model directly.
63
+
64
+ Arguments:
65
+ batched_input (list(dict)): A list over input images, each a
66
+ dictionary with the following keys. A prompt key can be
67
+ excluded if it is not present.
68
+ 'image': The image as a torch tensor in 3xHxW format,
69
+ already transformed for input to the model.
70
+ 'original_size': (tuple(int, int)) The original size of
71
+ the image before transformation, as (H, W).
72
+ 'point_coords': (torch.Tensor) Batched point prompts for
73
+ this image, with shape BxNx2. Already transformed to the
74
+ input frame of the model.
75
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
76
+ with shape BxN.
77
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78
+ Already transformed to the input frame of the model.
79
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80
+ in the form Bx1xHxW.
81
+ multimask_output (bool): Whether the model should predict multiple
82
+ disambiguating masks, or return a single mask.
83
+
84
+ Returns:
85
+ (list(dict)): A list over input images, where each element is
86
+ as dictionary with the following keys.
87
+ 'masks': (torch.Tensor) Batched binary mask predictions,
88
+ with shape BxCxHxW, where B is the number of input prompts,
89
+ C is determined by multimask_output, and (H, W) is the
90
+ original size of the image.
91
+ 'iou_predictions': (torch.Tensor) The model's predictions
92
+ of mask quality, in shape BxC.
93
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
94
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
95
+ to subsequent iterations of prediction.
96
+ """
97
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98
+ image_embeddings = self.image_encoder(input_images)
99
+
100
+ outputs = []
101
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
102
+ if "point_coords" in image_record:
103
+ points = (image_record["point_coords"], image_record["point_labels"])
104
+ else:
105
+ points = None
106
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
107
+ points=points,
108
+ boxes=image_record.get("boxes", None),
109
+ masks=image_record.get("mask_inputs", None),
110
+ )
111
+ low_res_masks, iou_predictions = self.mask_decoder(
112
+ image_embeddings=curr_embedding.unsqueeze(0),
113
+ image_pe=self.prompt_encoder.get_dense_pe(),
114
+ sparse_prompt_embeddings=sparse_embeddings,
115
+ dense_prompt_embeddings=dense_embeddings,
116
+ multimask_output=multimask_output,
117
+ )
118
+ masks = self.postprocess_masks(
119
+ low_res_masks,
120
+ input_size=image_record["image"].shape[-2:],
121
+ original_size=image_record["original_size"],
122
+ )
123
+ masks = masks > self.mask_threshold
124
+ outputs.append(
125
+ {
126
+ "masks": masks,
127
+ "iou_predictions": iou_predictions,
128
+ "low_res_logits": low_res_masks,
129
+ }
130
+ )
131
+ return outputs
132
+
133
+ def postprocess_masks(
134
+ self,
135
+ masks: torch.Tensor,
136
+ input_size: Tuple[int, ...],
137
+ original_size: Tuple[int, ...],
138
+ ) -> torch.Tensor:
139
+ """
140
+ Remove padding and upscale masks to the original image size.
141
+
142
+ Arguments:
143
+ masks (torch.Tensor): Batched masks from the mask_decoder,
144
+ in BxCxHxW format.
145
+ input_size (tuple(int, int)): The size of the image input to the
146
+ model, in (H, W) format. Used to remove padding.
147
+ original_size (tuple(int, int)): The original size of the image
148
+ before resizing for input to the model, in (H, W) format.
149
+
150
+ Returns:
151
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152
+ is given by original_size.
153
+ """
154
+ masks = F.interpolate(
155
+ masks,
156
+ (self.image_encoder.img_size, self.image_encoder.img_size),
157
+ mode="bilinear",
158
+ align_corners=False,
159
+ )
160
+ masks = masks[..., : input_size[0], : input_size[1]]
161
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162
+ return masks
163
+
164
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165
+ """Normalize pixel values and pad to a square input."""
166
+ # Normalize colors
167
+ x = (x - self.pixel_mean) / self.pixel_std
168
+
169
+ # Pad
170
+ h, w = x.shape[-2:]
171
+ padh = self.image_encoder.img_size - h
172
+ padw = self.image_encoder.img_size - w
173
+ x = F.pad(x, (0, padw, 0, padh))
174
+ return x
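The batched_input format documented in Sam.forward above is easy to get wrong, so here is a hedged sketch of one entry with a box prompt; the 1024-pixel long side and the example sizes are assumptions based on the released SAM settings:

import torch

batched_input = [
    {
        # 3xHxW image, already resized so the long side matches sam.image_encoder.img_size
        "image": torch.randint(0, 255, (3, 1024, 683), dtype=torch.uint8).float(),
        "original_size": (1500, 1000),                           # (H, W) before resizing
        "boxes": torch.tensor([[100.0, 100.0, 400.0, 500.0]]),   # Bx4, XYXY in the input frame
    }
]
# outputs = sam(batched_input, multimask_output=False)
# outputs[0]["masks"] -> Bx1xHxW boolean masks at the original (1500, 1000) resolution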
tools/ins_seg/sam/sam_cls/segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202
+
203
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207
+
208
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209
+ b, n, c = x.shape
210
+ x = x.reshape(b, n, num_heads, c // num_heads)
211
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212
+
213
+ def _recombine_heads(self, x: Tensor) -> Tensor:
214
+ b, n_heads, n_tokens, c_per_head = x.shape
215
+ x = x.transpose(1, 2)
216
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217
+
218
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219
+ # Input projections
220
+ q = self.q_proj(q)
221
+ k = self.k_proj(k)
222
+ v = self.v_proj(v)
223
+
224
+ # Separate into heads
225
+ q = self._separate_heads(q, self.num_heads)
226
+ k = self._separate_heads(k, self.num_heads)
227
+ v = self._separate_heads(v, self.num_heads)
228
+
229
+ # Attention
230
+ _, _, _, c_per_head = q.shape
231
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232
+ attn = attn / math.sqrt(c_per_head)
233
+ attn = torch.softmax(attn, dim=-1)
234
+
235
+ # Get output
236
+ out = attn @ v
237
+ out = self._recombine_heads(out)
238
+ out = self.out_proj(out)
239
+
240
+ return out
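To make the tensor shapes in TwoWayTransformer concrete, a small smoke test with random inputs; the constructor values below mirror common SAM settings and, like the import path, are assumptions rather than values taken from this diff:

import torch
from segment_anything.modeling.transformer import TwoWayTransformer  # adjust to the vendored path if needed

tt = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)          # same shape as the embedding
point_embedding = torch.randn(1, 7, 256)        # B x N_points x C (e.g. output tokens + prompts)
queries, keys = tt(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)                # (1, 7, 256) and (1, 4096, 256)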
tools/ins_seg/sam/sam_cls/segment_anything/predictor.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from segment_anything.modeling import Sam
11
+
12
+ from typing import Optional, Tuple
13
+
14
+ from .utils.transforms import ResizeLongestSide
15
+
16
+
17
+ class SamPredictor:
18
+ def __init__(
19
+ self,
20
+ sam_model: Sam,
21
+ ) -> None:
22
+ """
23
+ Uses SAM to calculate the image embedding for an image, and then
24
+ allow repeated, efficient mask prediction given prompts.
25
+
26
+ Arguments:
27
+ sam_model (Sam): The model to use for mask prediction.
28
+ """
29
+ super().__init__()
30
+ self.model = sam_model
31
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32
+ self.reset_image()
33
+
34
+ def set_image(
35
+ self,
36
+ image: np.ndarray,
37
+ image_format: str = "RGB",
38
+ ) -> None:
39
+ """
40
+ Calculates the image embeddings for the provided image, allowing
41
+ masks to be predicted with the 'predict' method.
42
+
43
+ Arguments:
44
+ image (np.ndarray): The image for calculating masks. Expects an
45
+ image in HWC uint8 format, with pixel values in [0, 255].
46
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
47
+ """
48
+ assert image_format in [
49
+ "RGB",
50
+ "BGR",
51
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52
+ if image_format != self.model.image_format:
53
+ image = image[..., ::-1]
54
+
55
+ # Transform the image to the form expected by the model
56
+ input_image = self.transform.apply_image(image)
57
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
58
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59
+
60
+ self.set_torch_image(input_image_torch, image.shape[:2])
61
+
62
+ @torch.no_grad()
63
+ def set_torch_image(
64
+ self,
65
+ transformed_image: torch.Tensor,
66
+ original_image_size: Tuple[int, ...],
67
+ ) -> None:
68
+ """
69
+ Calculates the image embeddings for the provided image, allowing
70
+ masks to be predicted with the 'predict' method. Expects the input
71
+ image to be already transformed to the format expected by the model.
72
+
73
+ Arguments:
74
+ transformed_image (torch.Tensor): The input image, with shape
75
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
76
+ original_image_size (tuple(int, int)): The size of the image
77
+ before transformation, in (H, W) format.
78
+ """
79
+ assert (
80
+ len(transformed_image.shape) == 4
81
+ and transformed_image.shape[1] == 3
82
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84
+ self.reset_image()
85
+
86
+ self.original_size = original_image_size
87
+ self.input_size = tuple(transformed_image.shape[-2:])
88
+ input_image = self.model.preprocess(transformed_image)
89
+ self.features = self.model.image_encoder(input_image)
90
+ self.is_image_set = True
91
+
92
+ def predict(
93
+ self,
94
+ point_coords: Optional[np.ndarray] = None,
95
+ point_labels: Optional[np.ndarray] = None,
96
+ box: Optional[np.ndarray] = None,
97
+ mask_input: Optional[np.ndarray] = None,
98
+ multimask_output: bool = True,
99
+ return_logits: bool = False,
100
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101
+ """
102
+ Predict masks for the given input prompts, using the currently set image.
103
+
104
+ Arguments:
105
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106
+ model. Each point is in (X,Y) in pixels.
107
+ point_labels (np.ndarray or None): A length N array of labels for the
108
+ point prompts. 1 indicates a foreground point and 0 indicates a
109
+ background point.
110
+ box (np.ndarray or None): A length 4 array given a box prompt to the
111
+ model, in XYXY format.
112
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
113
+ coming from a previous prediction iteration. Has form 1xHxW, where
114
+ for SAM, H=W=256.
115
+ multimask_output (bool): If true, the model will return three masks.
116
+ For ambiguous input prompts (such as a single click), this will often
117
+ produce better masks than a single prediction. If only a single
118
+ mask is needed, the model's predicted quality score can be used
119
+ to select the best mask. For non-ambiguous prompts, such as multiple
120
+ input prompts, multimask_output=False can give better results.
121
+ return_logits (bool): If true, returns un-thresholded masks logits
122
+ instead of a binary mask.
123
+
124
+ Returns:
125
+ (np.ndarray): The output masks in CxHxW format, where C is the
126
+ number of masks, and (H, W) is the original image size.
127
+ (np.ndarray): An array of length C containing the model's
128
+ predictions for the quality of each mask.
129
+ (np.ndarray): An array of shape CxHxW, where C is the number
130
+ of masks and H=W=256. These low resolution logits can be passed to
131
+ a subsequent iteration as mask input.
132
+ """
133
+ if not self.is_image_set:
134
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135
+
136
+ # Transform input prompts
137
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138
+ if point_coords is not None:
139
+ assert (
140
+ point_labels is not None
141
+ ), "point_labels must be supplied if point_coords is supplied."
142
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
143
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146
+ if box is not None:
147
+ box = self.transform.apply_boxes(box, self.original_size)
148
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149
+ box_torch = box_torch[None, :]
150
+ if mask_input is not None:
151
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152
+ mask_input_torch = mask_input_torch[None, :, :, :]
153
+
154
+ masks, iou_predictions, low_res_masks = self.predict_torch(
155
+ coords_torch,
156
+ labels_torch,
157
+ box_torch,
158
+ mask_input_torch,
159
+ multimask_output,
160
+ return_logits=return_logits,
161
+ )
162
+
163
+ masks_np = masks[0].detach().cpu().numpy()
164
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166
+ return masks_np, iou_predictions_np, low_res_masks_np
167
+
168
+ @torch.no_grad()
169
+ def predict_torch(
170
+ self,
171
+ point_coords: Optional[torch.Tensor],
172
+ point_labels: Optional[torch.Tensor],
173
+ boxes: Optional[torch.Tensor] = None,
174
+ mask_input: Optional[torch.Tensor] = None,
175
+ multimask_output: bool = True,
176
+ return_logits: bool = False,
177
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ """
179
+ Predict masks for the given input prompts, using the currently set image.
180
+ Input prompts are batched torch tensors and are expected to already be
181
+ transformed to the input frame using ResizeLongestSide.
182
+
183
+ Arguments:
184
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185
+ model. Each point is in (X,Y) in pixels.
186
+ point_labels (torch.Tensor or None): A BxN array of labels for the
187
+ point prompts. 1 indicates a foreground point and 0 indicates a
188
+ background point.
189
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
190
+ model, in XYXY format.
191
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
192
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
193
+ for SAM, H=W=256. Masks returned by a previous iteration of the
194
+ predict method do not need further transformation.
195
+ multimask_output (bool): If true, the model will return three masks.
196
+ For ambiguous input prompts (such as a single click), this will often
197
+ produce better masks than a single prediction. If only a single
198
+ mask is needed, the model's predicted quality score can be used
199
+ to select the best mask. For non-ambiguous prompts, such as multiple
200
+ input prompts, multimask_output=False can give better results.
201
+ return_logits (bool): If true, returns un-thresholded masks logits
202
+ instead of a binary mask.
203
+
204
+ Returns:
205
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
206
+ number of masks, and (H, W) is the original image size.
207
+ (torch.Tensor): An array of shape BxC containing the model's
208
+ predictions for the quality of each mask.
209
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
210
+ of masks and H=W=256. These low res logits can be passed to
211
+ a subsequent iteration as mask input.
212
+ """
213
+ if not self.is_image_set:
214
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215
+
216
+ if point_coords is not None:
217
+ points = (point_coords, point_labels)
218
+ else:
219
+ points = None
220
+
221
+ # Embed prompts
222
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223
+ points=points,
224
+ boxes=boxes,
225
+ masks=mask_input,
226
+ )
227
+
228
+ # Predict masks
229
+ low_res_masks, iou_predictions = self.model.mask_decoder(
230
+ image_embeddings=self.features,
231
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
232
+ sparse_prompt_embeddings=sparse_embeddings,
233
+ dense_prompt_embeddings=dense_embeddings,
234
+ multimask_output=multimask_output,
235
+ )
236
+
237
+ # Upscale the masks to the original image resolution
238
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239
+
240
+ if not return_logits:
241
+ masks = masks > self.model.mask_threshold
242
+
243
+ return masks, iou_predictions, low_res_masks
244
+
245
+ def get_image_embedding(self) -> torch.Tensor:
246
+ """
247
+ Returns the image embeddings for the currently set image, with
248
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
250
+ """
251
+ if not self.is_image_set:
252
+ raise RuntimeError(
253
+ "An image must be set with .set_image(...) to generate an embedding."
254
+ )
255
+ assert self.features is not None, "Features must exist if an image has been set."
256
+ return self.features
257
+
258
+ @property
259
+ def device(self) -> torch.device:
260
+ return self.model.device
261
+
262
+ def reset_image(self) -> None:
263
+ """Resets the currently set image."""
264
+ self.is_image_set = False
265
+ self.features = None
266
+ self.orig_h = None
267
+ self.orig_w = None
268
+ self.input_h = None
269
+ self.input_w = None
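Typical SamPredictor usage, as described in the predict docstring above: set an image once, then prompt repeatedly. The registry key and checkpoint filename are assumptions about the released SAM weights, and the import path may differ for this repo's vendored copy:

import numpy as np
from segment_anything import sam_model_registry, SamPredictor  # adjust to the vendored path if needed

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
image = np.zeros((768, 1024, 3), dtype=np.uint8)    # HWC uint8 RGB image
predictor.set_image(image)
point = np.array([[500, 375]])                      # one (X, Y) click in pixels
label = np.array([1])                               # 1 = foreground
masks, scores, low_res = predictor.predict(point_coords=point, point_labels=label,
                                           multimask_output=True)
best = masks[np.argmax(scores)]                     # masks has shape (3, 768, 1024)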
tools/ins_seg/sam/sam_cls/segment_anything/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
tools/ins_seg/sam/sam_cls/segment_anything/utils/amg.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ import math
11
+ from copy import deepcopy
12
+ from itertools import product
13
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
14
+
15
+
16
+ class MaskData:
17
+ """
18
+ A structure for storing masks and their related data in batched format.
19
+ Implements basic filtering and concatenation.
20
+ """
21
+
22
+ def __init__(self, **kwargs) -> None:
23
+ for v in kwargs.values():
24
+ assert isinstance(
25
+ v, (list, np.ndarray, torch.Tensor)
26
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
27
+ self._stats = dict(**kwargs)
28
+
29
+ def __setitem__(self, key: str, item: Any) -> None:
30
+ assert isinstance(
31
+ item, (list, np.ndarray, torch.Tensor)
32
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
33
+ self._stats[key] = item
34
+
35
+ def __delitem__(self, key: str) -> None:
36
+ del self._stats[key]
37
+
38
+ def __getitem__(self, key: str) -> Any:
39
+ return self._stats[key]
40
+
41
+ def items(self) -> ItemsView[str, Any]:
42
+ return self._stats.items()
43
+
44
+ def filter(self, keep: torch.Tensor) -> None:
45
+ for k, v in self._stats.items():
46
+ if v is None:
47
+ self._stats[k] = None
48
+ elif isinstance(v, torch.Tensor):
49
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50
+ elif isinstance(v, np.ndarray):
51
+ self._stats[k] = v[keep.detach().cpu().numpy()]
52
+ elif isinstance(v, list) and keep.dtype == torch.bool:
53
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54
+ elif isinstance(v, list):
55
+ self._stats[k] = [v[i] for i in keep]
56
+ else:
57
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58
+
59
+ def cat(self, new_stats: "MaskData") -> None:
60
+ for k, v in new_stats.items():
61
+ if k not in self._stats or self._stats[k] is None:
62
+ self._stats[k] = deepcopy(v)
63
+ elif isinstance(v, torch.Tensor):
64
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65
+ elif isinstance(v, np.ndarray):
66
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67
+ elif isinstance(v, list):
68
+ self._stats[k] = self._stats[k] + deepcopy(v)
69
+ else:
70
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71
+
72
+ def to_numpy(self) -> None:
73
+ for k, v in self._stats.items():
74
+ if isinstance(v, torch.Tensor):
75
+ self._stats[k] = v.detach().cpu().numpy()
76
+
77
+
78
+ def is_box_near_crop_edge(
79
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80
+ ) -> torch.Tensor:
81
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
82
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88
+ return torch.any(near_crop_edge, dim=1)
89
+
90
+
91
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92
+ box_xywh = deepcopy(box_xyxy)
93
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
94
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
95
+ return box_xywh
96
+
97
+
98
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99
+ assert len(args) > 0 and all(
100
+ len(a) == len(args[0]) for a in args
101
+ ), "Batched iteration must have inputs of all the same size."
102
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103
+ for b in range(n_batches):
104
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105
+
106
+
107
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108
+ """
109
+ Encodes masks to an uncompressed RLE, in the format expected by
110
+ pycoco tools.
111
+ """
112
+ # Put in fortran order and flatten h,w
113
+ b, h, w = tensor.shape
114
+ tensor = tensor.permute(0, 2, 1).flatten(1)
115
+
116
+ # Compute change indices
117
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
118
+ change_indices = diff.nonzero()
119
+
120
+ # Encode run length
121
+ out = []
122
+ for i in range(b):
123
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124
+ cur_idxs = torch.cat(
125
+ [
126
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127
+ cur_idxs + 1,
128
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129
+ ]
130
+ )
131
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132
+ counts = [] if tensor[i, 0] == 0 else [0]
133
+ counts.extend(btw_idxs.detach().cpu().tolist())
134
+ out.append({"size": [h, w], "counts": counts})
135
+ return out
136
+
137
+
138
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139
+ """Compute a binary mask from an uncompressed RLE."""
140
+ h, w = rle["size"]
141
+ mask = np.empty(h * w, dtype=bool)
142
+ idx = 0
143
+ parity = False
144
+ for count in rle["counts"]:
145
+ mask[idx : idx + count] = parity
146
+ idx += count
147
+ parity ^= True
148
+ mask = mask.reshape(w, h)
149
+ return mask.transpose() # Put in C order
150
+
151
+
152
+ def area_from_rle(rle: Dict[str, Any]) -> int:
153
+ return sum(rle["counts"][1::2])
154
+
155
+
156
+ def calculate_stability_score(
157
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158
+ ) -> torch.Tensor:
159
+ """
160
+ Computes the stability score for a batch of masks. The stability
161
+ score is the IoU between the binary masks obtained by thresholding
162
+ the predicted mask logits at high and low values.
163
+ """
164
+ # One mask is always contained inside the other.
165
+ # Save memory by preventing unnecessary cast to torch.int64
166
+ intersections = (
167
+ (masks > (mask_threshold + threshold_offset))
168
+ .sum(-1, dtype=torch.int16)
169
+ .sum(-1, dtype=torch.int32)
170
+ )
171
+ unions = (
172
+ (masks > (mask_threshold - threshold_offset))
173
+ .sum(-1, dtype=torch.int16)
174
+ .sum(-1, dtype=torch.int32)
175
+ )
176
+ return intersections / unions
177
+
178
+
179
+ def build_point_grid(n_per_side: int) -> np.ndarray:
180
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181
+ offset = 1 / (2 * n_per_side)
182
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186
+ return points
187
+
188
+
189
+ def build_all_layer_point_grids(
190
+ n_per_side: int, n_layers: int, scale_per_layer: int
191
+ ) -> List[np.ndarray]:
192
+ """Generates point grids for all crop layers."""
193
+ points_by_layer = []
194
+ for i in range(n_layers + 1):
195
+ n_points = int(n_per_side / (scale_per_layer**i))
196
+ points_by_layer.append(build_point_grid(n_points))
197
+ return points_by_layer
198
+
199
+
200
+ def generate_crop_boxes(
201
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202
+ ) -> Tuple[List[List[int]], List[int]]:
203
+ """
204
+ Generates a list of crop boxes of different sizes. Each layer
205
+ has (2**i)**2 boxes for the ith layer.
206
+ """
207
+ crop_boxes, layer_idxs = [], []
208
+ im_h, im_w = im_size
209
+ short_side = min(im_h, im_w)
210
+
211
+ # Original image
212
+ crop_boxes.append([0, 0, im_w, im_h])
213
+ layer_idxs.append(0)
214
+
215
+ def crop_len(orig_len, n_crops, overlap):
216
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217
+
218
+ for i_layer in range(n_layers):
219
+ n_crops_per_side = 2 ** (i_layer + 1)
220
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221
+
222
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
223
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
224
+
225
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227
+
228
+ # Crops in XYWH format
229
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
230
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231
+ crop_boxes.append(box)
232
+ layer_idxs.append(i_layer + 1)
233
+
234
+ return crop_boxes, layer_idxs
235
+
236
+
237
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238
+ x0, y0, _, _ = crop_box
239
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240
+ # Check if boxes has a channel dimension
241
+ if len(boxes.shape) == 3:
242
+ offset = offset.unsqueeze(1)
243
+ return boxes + offset
244
+
245
+
246
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247
+ x0, y0, _, _ = crop_box
248
+ offset = torch.tensor([[x0, y0]], device=points.device)
249
+ # Check if points has a channel dimension
250
+ if len(points.shape) == 3:
251
+ offset = offset.unsqueeze(1)
252
+ return points + offset
253
+
254
+
255
+ def uncrop_masks(
256
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257
+ ) -> torch.Tensor:
258
+ x0, y0, x1, y1 = crop_box
259
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260
+ return masks
261
+ # Coordinate transform masks
262
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
264
+ return torch.nn.functional.pad(masks, pad, value=0)
265
+
266
+
267
+ def remove_small_regions(
268
+ mask: np.ndarray, area_thresh: float, mode: str
269
+ ) -> Tuple[np.ndarray, bool]:
270
+ """
271
+ Removes small disconnected regions and holes in a mask. Returns the
272
+ mask and an indicator of if the mask has been modified.
273
+ """
274
+ import cv2 # type: ignore
275
+
276
+ assert mode in ["holes", "islands"]
277
+ correct_holes = mode == "holes"
278
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
279
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280
+ sizes = stats[:, -1][1:] # Row 0 is background label
281
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282
+ if len(small_regions) == 0:
283
+ return mask, False
284
+ fill_labels = [0] + small_regions
285
+ if not correct_holes:
286
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287
+ # If every region is below threshold, keep largest
288
+ if len(fill_labels) == 0:
289
+ fill_labels = [int(np.argmax(sizes)) + 1]
290
+ mask = np.isin(regions, fill_labels)
291
+ return mask, True
292
+
293
+
294
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295
+ from pycocotools import mask as mask_utils # type: ignore
296
+
297
+ h, w = uncompressed_rle["size"]
298
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300
+ return rle
301
+
302
+
303
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304
+ """
305
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307
+ """
308
+ # torch.max below raises an error on empty inputs, just skip in this case
309
+ if torch.numel(masks) == 0:
310
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311
+
312
+ # Normalize shape to CxHxW
313
+ shape = masks.shape
314
+ h, w = shape[-2:]
315
+ if len(shape) > 2:
316
+ masks = masks.flatten(0, -3)
317
+ else:
318
+ masks = masks.unsqueeze(0)
319
+
320
+ # Get top and bottom edges
321
+ in_height, _ = torch.max(masks, dim=-1)
322
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324
+ in_height_coords = in_height_coords + h * (~in_height)
325
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
326
+
327
+ # Get left and right edges
328
+ in_width, _ = torch.max(masks, dim=-2)
329
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
331
+ in_width_coords = in_width_coords + w * (~in_width)
332
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
333
+
334
+ # If the mask is empty the right edge will be to the left of the left edge.
335
+ # Replace these boxes with [0, 0, 0, 0]
336
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338
+ out = out * (~empty_filter).unsqueeze(-1)
339
+
340
+ # Return to original shape
341
+ if len(shape) > 2:
342
+ out = out.reshape(*shape[:-2], 4)
343
+ else:
344
+ out = out[0]
345
+
346
+ return out
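A small round trip through the RLE helpers above, showing that mask_to_rle_pytorch and rle_to_mask invert each other and that area_from_rle sums the foreground runs; the import path is an assumption about how the vendored package is exposed:

import torch
from segment_anything.utils.amg import area_from_rle, mask_to_rle_pytorch, rle_to_mask  # adjust to the vendored path if needed

masks = torch.zeros(1, 4, 4, dtype=torch.bool)
masks[0, 1:3, 1:3] = True                          # a 2x2 foreground square
rles = mask_to_rle_pytorch(masks)                  # uncompressed, column-major counts
print(rles[0]["size"], area_from_rle(rles[0]))     # [4, 4] and area 4
recovered = rle_to_mask(rles[0])                   # back to an HxW numpy bool array
assert (recovered == masks[0].numpy()).all()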
tools/ins_seg/sam/sam_cls/segment_anything/utils/onnx.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Tuple
12
+
13
+ from ..modeling import Sam
14
+ from .amg import calculate_stability_score
15
+
16
+
17
+ class SamOnnxModel(nn.Module):
18
+ """
19
+ This model should not be called directly, but is used in ONNX export.
20
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21
+ with some functions modified to enable model tracing. Also supports extra
22
+ options controlling what information. See the ONNX export script for details.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ return_single_mask: bool,
29
+ use_stability_score: bool = False,
30
+ return_extra_metrics: bool = False,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.mask_decoder = model.mask_decoder
34
+ self.model = model
35
+ self.img_size = model.image_encoder.img_size
36
+ self.return_single_mask = return_single_mask
37
+ self.use_stability_score = use_stability_score
38
+ self.stability_score_offset = 1.0
39
+ self.return_extra_metrics = return_extra_metrics
40
+
41
+ @staticmethod
42
+ def resize_longest_image_size(
43
+ input_image_size: torch.Tensor, longest_side: int
44
+ ) -> torch.Tensor:
45
+ input_image_size = input_image_size.to(torch.float32)
46
+ scale = longest_side / torch.max(input_image_size)
47
+ transformed_size = scale * input_image_size
48
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49
+ return transformed_size
50
+
51
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52
+ point_coords = point_coords + 0.5
53
+ point_coords = point_coords / self.img_size
54
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56
+
57
+ point_embedding = point_embedding * (point_labels != -1)
58
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59
+ point_labels == -1
60
+ )
61
+
62
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
63
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64
+ i
65
+ ].weight * (point_labels == i)
66
+
67
+ return point_embedding
68
+
69
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71
+ mask_embedding = mask_embedding + (
72
+ 1 - has_mask_input
73
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74
+ return mask_embedding
75
+
76
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77
+ masks = F.interpolate(
78
+ masks,
79
+ size=(self.img_size, self.img_size),
80
+ mode="bilinear",
81
+ align_corners=False,
82
+ )
83
+
84
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86
+
87
+ orig_im_size = orig_im_size.to(torch.int64)
88
+ h, w = orig_im_size[0], orig_im_size[1]
89
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90
+ return masks
91
+
92
+ def select_masks(
93
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
95
+ # Determine if we should return the multiclick mask or not from the number of points.
96
+ # The reweighting is used to avoid control flow.
97
+ score_reweight = torch.tensor(
98
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99
+ ).to(iou_preds.device)
100
+ score = iou_preds + (num_points - 2.5) * score_reweight
101
+ best_idx = torch.argmax(score, dim=1)
102
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104
+
105
+ return masks, iou_preds
106
+
107
+ @torch.no_grad()
108
+ def forward(
109
+ self,
110
+ image_embeddings: torch.Tensor,
111
+ point_coords: torch.Tensor,
112
+ point_labels: torch.Tensor,
113
+ mask_input: torch.Tensor,
114
+ has_mask_input: torch.Tensor,
115
+ orig_im_size: torch.Tensor,
116
+ ):
117
+ sparse_embedding = self._embed_points(point_coords, point_labels)
118
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
119
+
120
+ masks, scores = self.model.mask_decoder.predict_masks(
121
+ image_embeddings=image_embeddings,
122
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
123
+ sparse_prompt_embeddings=sparse_embedding,
124
+ dense_prompt_embeddings=dense_embedding,
125
+ )
126
+
127
+ if self.use_stability_score:
128
+ scores = calculate_stability_score(
129
+ masks, self.model.mask_threshold, self.stability_score_offset
130
+ )
131
+
132
+ if self.return_single_mask:
133
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134
+
135
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136
+
137
+ if self.return_extra_metrics:
138
+ stability_scores = calculate_stability_score(
139
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140
+ )
141
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142
+ return upscaled_masks, scores, stability_scores, areas, masks
143
+
144
+ return upscaled_masks, scores, masks
tools/ins_seg/sam/sam_cls/segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
+ class ResizeLongestSide:
17
+ """
18
+ Resizes images to the longest side 'target_length', as well as provides
19
+ methods for resizing coordinates and boxes. Provides methods for
20
+ transforming both numpy array and batched torch tensors.
21
+ """
22
+
23
+ def __init__(self, target_length: int) -> None:
24
+ self.target_length = target_length
25
+
26
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
27
+ """
28
+ Expects a numpy array with shape HxWxC in uint8 format.
29
+ """
30
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31
+ return np.array(resize(to_pil_image(image), target_size))
32
+
33
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34
+ """
35
+ Expects a numpy array of length 2 in the final dimension. Requires the
36
+ original image size in (H, W) format.
37
+ """
38
+ old_h, old_w = original_size
39
+ new_h, new_w = self.get_preprocess_shape(
40
+ original_size[0], original_size[1], self.target_length
41
+ )
42
+ coords = deepcopy(coords).astype(float)
43
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
44
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
45
+ return coords
46
+
47
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48
+ """
49
+ Expects a numpy array shape Bx4. Requires the original image size
50
+ in (H, W) format.
51
+ """
52
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53
+ return boxes.reshape(-1, 4)
54
+
55
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Expects batched images with shape BxCxHxW and float format. This
58
+ transformation may not exactly match apply_image. apply_image is
59
+ the transformation expected by the model.
60
+ """
61
+ # Expects an image in BCHW format. May not exactly match apply_image.
62
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63
+ return F.interpolate(
64
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
65
+ )
66
+
67
+ def apply_coords_torch(
68
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
69
+ ) -> torch.Tensor:
70
+ """
71
+ Expects a torch tensor with length 2 in the last dimension. Requires the
72
+ original image size in (H, W) format.
73
+ """
74
+ old_h, old_w = original_size
75
+ new_h, new_w = self.get_preprocess_shape(
76
+ original_size[0], original_size[1], self.target_length
77
+ )
78
+ coords = deepcopy(coords).to(torch.float)
79
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
80
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
81
+ return coords
82
+
83
+ def apply_boxes_torch(
84
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85
+ ) -> torch.Tensor:
86
+ """
87
+ Expects a torch tensor with shape Bx4. Requires the original image
88
+ size in (H, W) format.
89
+ """
90
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91
+ return boxes.reshape(-1, 4)
92
+
93
+ @staticmethod
94
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95
+ """
96
+ Compute the output size given input size and target long side length.
97
+ """
98
+ scale = long_side_length * 1.0 / max(oldh, oldw)
99
+ newh, neww = oldh * scale, oldw * scale
100
+ neww = int(neww + 0.5)
101
+ newh = int(newh + 0.5)
102
+ return (newh, neww)
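A quick usage sketch for ResizeLongestSide on a dummy image. The 1024-pixel target length matches SAM's usual input resolution and the import path assumes the vendored package location added in this commit; both are assumptions, not requirements of the class.

import numpy as np
from tools.ins_seg.sam.sam_cls.segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)                     # SAM's standard long-side length
image = np.zeros((600, 800, 3), dtype=np.uint8)         # dummy HxWxC uint8 image
resized = transform.apply_image(image)                  # longest side becomes 1024 -> shape (768, 1024, 3)
boxes = np.array([[10.0, 20.0, 200.0, 300.0]])          # Bx4 boxes in original-image coordinates
scaled = transform.apply_boxes(boxes, image.shape[:2])  # boxes rescaled to the resized resolution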
tools/predict.py ADDED
@@ -0,0 +1,46 @@
+ import argparse
+ import os
+ import sys
+ sys.path.insert(0, sys.path[0]+'/..')
+ from mmengine.config import Config, DictAction
+ from mmengine.logging import print_log
+ from mmengine.runner import Runner
+ from mmpl.engine.runner import PLRunner
+ import os.path as osp
+ from mmpl.registry import RUNNERS
+ from mmpl.utils import register_all_modules
+ register_all_modules()
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Predict with a pl model')
+     parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py',
+                         help='config file path')
+     parser.add_argument('--status', default='predict', help='fit, test, predict, or validate', choices=['fit', 'test', 'predict', 'validate'])
+     parser.add_argument('--ckpt-path',
+                         default='pretrain/whu/last.ckpt',
+                         help='checkpoint path')
+     parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     cfg = Config.fromfile(args.config)
+     if args.work_dir is not None:
+         cfg.trainer_cfg['default_root_dir'] = args.work_dir
+     elif cfg.trainer_cfg.get('default_root_dir', None) is None:
+         # use config filename as default work_dir if cfg.work_dir is None
+         cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
+     cfg.trainer_cfg['logger'] = None
+     if 'runner_type' not in cfg:
+         runner = PLRunner.from_cfg(cfg)
+     else:
+         runner = RUNNERS.build(cfg)
+     runner.run(args.status, ckpt_path=args.ckpt_path)
+
+
+ if __name__ == '__main__':
+     main()
+
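The same prediction run can also be launched programmatically, mirroring what tools/predict.py does above; the config and checkpoint paths are simply the script's defaults, used here as placeholders.

from mmengine.config import Config
from mmpl.engine.runner import PLRunner
from mmpl.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/rsprompter/rsprompter_anchor_whu_config.py')
cfg.trainer_cfg['logger'] = None                 # disable the experiment logger for inference, as the script does
runner = PLRunner.from_cfg(cfg)
runner.run('predict', ckpt_path='pretrain/whu/last.ckpt')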
tools/test.py ADDED
@@ -0,0 +1,43 @@
+ import argparse
+ import os
+ import sys
+ sys.path.insert(0, sys.path[0]+'/..')
+ from mmengine.config import Config, DictAction
+ from mmengine.logging import print_log
+ from mmengine.runner import Runner
+ from mmpl.engine.runner import PLRunner
+ import os.path as osp
+ from mmpl.registry import RUNNERS
+ from mmpl.utils import register_all_modules
+ register_all_modules()
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Test a pl model')
+     parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py', help='config file path')
+     parser.add_argument('--status', default='test', help='fit, test, predict, or validate', choices=['fit', 'test', 'predict', 'validate'])
+     parser.add_argument('--ckpt-path', default='pretrain/last.pth',
+                         help='checkpoint path')
+     parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
+     args = parser.parse_args()
+     return args
+
+ def main():
+     args = parse_args()
+     cfg = Config.fromfile(args.config)
+     if args.work_dir is not None:
+         cfg.trainer_cfg['default_root_dir'] = args.work_dir
+     elif cfg.trainer_cfg.get('default_root_dir', None) is None:
+         # use config filename as default work_dir if cfg.work_dir is None
+         cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
+
+     if 'runner_type' not in cfg:
+         runner = PLRunner.from_cfg(cfg)
+     else:
+         runner = RUNNERS.build(cfg)
+     runner.run(args.status, ckpt_path=args.ckpt_path)
+
+
+ if __name__ == '__main__':
+     main()
+
tools/train.py ADDED
@@ -0,0 +1,48 @@
+ import argparse
+ import os
+ import sys
+ import torch
+ sys.path.insert(0, sys.path[0]+'/..')
+ from mmengine.config import Config, DictAction
+ from mmpl.engine.runner import PLRunner
+ import os.path as osp
+ from mmpl.registry import RUNNERS
+ from mmpl.utils import register_all_modules
+
+ torch.set_float32_matmul_precision('high')
+ register_all_modules()
+ # TORCH_DISTRIBUTED_DEBUG=DETAIL
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Train a pl model')
+     parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py',
+                         help='train config file path')
+     parser.add_argument('--is-debug', default=False, action='store_true', help='debug mode')
+     parser.add_argument('--ckpt-path', default=None, help='checkpoint path')
+     parser.add_argument('--status', default='fit', help='fit, test, predict, or validate', choices=['fit', 'test', 'predict', 'validate'])
+     parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     cfg = Config.fromfile(args.config)
+     if args.work_dir is not None:
+         cfg.trainer_cfg['default_root_dir'] = args.work_dir
+     elif cfg.trainer_cfg.get('default_root_dir', None) is None:
+         # use config filename as default work_dir if cfg.work_dir is None
+         cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
+     if args.is_debug:
+         cfg.trainer_cfg['fast_dev_run'] = True
+         cfg.trainer_cfg['logger'] = None
+     if 'runner_type' not in cfg:
+         runner = PLRunner.from_cfg(cfg)
+     else:
+         runner = RUNNERS.build(cfg)
+     runner.run(args.status, ckpt_path=args.ckpt_path)
+
+
+ if __name__ == '__main__':
+     main()
+
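Likewise, a minimal smoke-test sketch for the training entry point, reproducing the effect of --is-debug (fast_dev_run pushes a single batch through the fit loop); the config path is the script's default and stands in for a real experiment config.

import torch
from mmengine.config import Config
from mmpl.engine.runner import PLRunner
from mmpl.utils import register_all_modules

torch.set_float32_matmul_precision('high')
register_all_modules()
cfg = Config.fromfile('configs/rsprompter/rsprompter_anchor_whu_config.py')
cfg.trainer_cfg['fast_dev_run'] = True           # single-batch debug run, same as passing --is-debug
cfg.trainer_cfg['logger'] = None
runner = PLRunner.from_cfg(cfg)
runner.run('fit')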
visualizer/test_img.jpg ADDED